{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.7.3 64-bit ('.venv': venv)",
      "language": "python",
      "name": "python37364bitvenvvenv7b03464ca52642a09f5ea3f892125398"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3-final"
    },
    "toc": {
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "logistic-regression.ipynb",
      "provenance": [],
      "include_colab_link": true
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/DeepSE/deeplearning-models/blob/master/pytorch_ipynb/basic-ml/logistic-regression.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_zp9vuP-AsLd",
        "colab_type": "text"
      },
      "source": [
        "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
        "- Author: Sebastian Raschka\n",
        "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9o7GULZOBZxX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install -q IPython\n",
        "!pip install -q ipykernel\n",
        "!pip install -q torch\n",
        "!pip install -q watermark\n",
        "!pip install -q matplotlib\n",
        "!pip install -q tensorwatch\n",
        "# `sklearn` on PyPI is a deprecated placeholder whose install now fails;\n",
        "# the library is published under the name `scikit-learn`.\n",
        "!pip install -q scikit-learn\n",
        "!pip install -q pandas\n",
        "!pip install -q pydot\n",
        "!pip install -q hiddenlayer\n",
        "!pip install -q graphviz"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "QJQPAzOlAsLf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        },
        "outputId": "ada063d9-82d5-4bc8-f2ca-55bf6f8e203f"
      },
      "source": [
        "%load_ext watermark\n",
        "%watermark -a 'Sebastian Raschka' -v -p torch"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sebastian Raschka \n",
            "\n",
            "CPython 3.6.9\n",
            "IPython 5.5.0\n",
            "\n",
            "torch 1.5.1+cu101\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TUtJ6qP9AsLm",
        "colab_type": "text"
      },
      "source": [
        "- Runs on CPU or GPU (if available)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ohPY1L9DAsLn",
        "colab_type": "text"
      },
      "source": [
        "# Model Zoo -- Logistic Regression"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yGS_aC2cAsLo",
        "colab_type": "text"
      },
      "source": [
        "Implementation of *classic* logistic regression for binary class labels."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qUcQdQkXAsLp",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lIf6X212AsLq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%matplotlib inline\n",
        "from io import BytesIO\n",
        "from urllib.request import urlopen\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HbUOSDwcAsLv",
        "colab_type": "text"
      },
      "source": [
        "## Preparing a toy dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "j2waW6yPAsLw",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 184
        },
        "outputId": "ef7fb51a-6a64-49b0-ec4a-cbc341ebbf04"
      },
      "source": [
        "##########################\n",
        "### DATASET\n",
        "##########################\n",
        "\n",
        "IRIS_URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'\n",
        "\n",
        "# Download the raw CSV once and close the connection immediately. This avoids\n",
        "# `np.lib.DataSource` (no longer exposed by NumPy 2.x) and the file handle the\n",
        "# previous version never closed.\n",
        "with urlopen(IRIS_URL) as response:\n",
        "    raw_csv = response.read()\n",
        "\n",
        "# Keep the first two feature columns of the first 100 rows. The labels below\n",
        "# assume the file lists one class in rows 0-49 and a second class in rows 50-99.\n",
        "x = np.genfromtxt(BytesIO(raw_csv), delimiter=',', usecols=range(2), max_rows=100)\n",
        "y = np.zeros(100)\n",
        "y[50:] = 1\n",
        "\n",
        "# Shuffle with a fixed seed and hold out 25 examples as the test set.\n",
        "np.random.seed(1)\n",
        "idx = np.arange(y.shape[0])\n",
        "np.random.shuffle(idx)\n",
        "X_test, y_test = x[idx[:25]], y[idx[:25]]\n",
        "X_train, y_train = x[idx[25:]], y[idx[25:]]\n",
        "# Standardize both splits using statistics from the training split only.\n",
        "mu, std = np.mean(X_train, axis=0), np.std(X_train, axis=0)\n",
        "X_train, X_test = (X_train - mu) / std, (X_test - mu) / std\n",
        "\n",
        "fig, ax = plt.subplots(1, 2, figsize=(7, 2.5))\n",
        "ax[0].scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1])\n",
        "ax[0].scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1])\n",
        "ax[1].scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1])\n",
        "ax[1].scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1])\n",
        "ax[0].set_title('Training set')\n",
        "ax[1].set_title('Test set')\n",
        "plt.show()"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaoAAACnCAYAAABAZhicAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWmElEQVR4nO3df4xU13UH8O/ZydpdydGuHCM33h8G2S6RZUgRK2OLP6xCI6gLDnEaEixZoUVaRUqUpri00FiUIFemQsUNqv9BtUUquySr2iYOdkRc7AbFCoRd4yy2MZEbF9hNpKxlQWIVlWU5/ePNsju7772ZO+++ufe+9/1IaJnH7HvH43fnztx77rmiqiAiIvJVm+sAiIiI0rCjIiIir7GjIiIir7GjIiIir7GjIiIir7GjIiIir33MxUVvuukmnT9/votLE1kxPDz8garOcx3HFLYpKoKkdpW5oxKR3wNwFMD11fP9h6r+fdrvzJ8/H0NDQ1kvTeSMiJx1HcNMbFNUBEntysY3qv8DsEJVPxKRdgA/EZEfquoxC+cmIqKSyzxHpZGPqg/bq39Y7sIHI4PAE3cBO7qinyODriMiyg/v98KyMkclIhUAwwBuB/Ckqh63cV7KYGQQ+MHXgYlL0eOL56PHALB4vbu4iPLA+73QrGT9qeqkqv4hgB4Ad4vIXbOfIyIDIjIkIkPj4+M2LktpjuycbrRTJi5Fx4mKhvd7oVlNT1fVCwBeA7A65t/2qWq/qvbPm+dNslRxXRw1O04UMt7vhZa5oxKReSLSVf17B4DPAHg363kpo84es+NEIeP9Xmg2vlF9EsBrIjIC4ASAV1T1kIXzUhYrtwPtHbXH2jui40RFw/u90DInU6jqCIAlFmIhm6YmkI/sjIY/OnuiRsuJZXJtZND+fcn7vdCcVKagFlm8ng2V/JJndh7v98JirT8iah1m51ET2FERUeswO4+awI6KiFqH2XnUBHZURNQ6zM6jJrCjIqLWWbweWLsX6OwFINHPtXuZBEGpmPVH5BkR6QXwbwBuRlTgeZ+qftttVBYxO48MsaMi8s8VAI+o6hsi8nEAwyLyiqq+4zowIhc49EfkGVX9taq+Uf377wCcBtDtNioid9hREXlMROYjqvzCrXOotNhREXlKRG4A8ByAb6jqb2P+nVvnUCmwoyLykIi0I+qknlXV5+Oew61zqCzYURF5RkQEwFMATqvqHtfxELnGjorIP8sBPAxghYi8Wf1zv+ugiFzJnJ5e+DUfRZTHNgtkjar+BIC4joPIFzbWUXHNR0jy3GaBiCgHmYf+uOYjMNxmgYgCY3WOims+AsBtFogoMNZKKDWy5gPAAAD09fXZuiyZ6uyJhvvijhPZcGgzMLwf0ElAKsDSjcAaJi9S86x8o+KaD8sObQa+dSOwozP6eWizvXNzmwXK06HNwNBTUScFRD+HnrJ7D1PpZO6ouObDsrwbOrdZoDwN7zc7TtQAG0N/U2s+TonIm9Vjf6eqL1s4d/mkNXRbwyfcZoHyMvUBq9HjRA3I3FFxzYdlbOgUMqnE36tSaX0sVBisTOGbpAbNhk4hWLrR7DhRA9hR2TQyCDxxF7CjK/o5Mmh+DpsN3UY8RCbW7AH6N01/sJJK9NjGsDXv59LiDr+22Kr4MNWgs6b3sgIFubJmj/10dN7PpcaOypa0ig+mDclGQ7cZD5FrvJ9LjUN/tvhW8cG3eIiy4P1cauyobEmq7OCq4oNv8RBlwfu51NhR2bJyO9DWXnusrT294kOek8OsQEFFwvu51DhHZZNI+uOZ8p4cnjoH952iIuD9XGrsqGw5shOYvFx7bPJy8mRvKyaHWYGCioT3c2lx6M8W08leTg4TETWEHZUtppO9nBwmImpIMTuqvFewx53fdLKXk8OUQkSeFpHfiMhbrmMhcq14HdVUksLF8wB0OknBVmeVdH7AbPsMbrdB6f
YDWO06CCIfFC+ZIu8khbTz/9VbZtfg5DAlUNWjIjLfdRxNGRl0k53n6rqUO1s7/PozTJF3kgKTIMgTIjIgIkMiMjQ+Pu46nEjeIxq+XZdawtbQ3374MkyRd5ICkyDIE6q6T1X7VbV/3rx5rsOJpI04FPG61BJWOipVPQrgQxvnyizvJIWV24G2WXtDtVWi40lJHNyegMrC1YgDRzoKrXhzVHmvYD93DLg6awfTq5PAyWeA0Z/NrTRx7hjw83/n9gRUDp091eG3mONFvC61RMuy/lo6nr54fZTYsOOCeYJDPcP744+//+P4oYfh/RySIGMicgDATwEsFJFREdnkOqaGuFp2weUehdayjsrL8fRm6GT95zTyfA5JUApV3aCqn1TVdlXtUdWnXMfUkMXrgU8/VLvD76cfsvNhMW0Incs9Cq14Q395k4pZZ5X0fA5JUBGNDEZD3VP3vE5Gj/vuydZpNFLEmcs9CstWenqYwxTNWLox/viC++KHHpZu5JAElUde2XfM6is1K9+oVHWDjfNYY7rw79DmaC5JJ6NvQEs3Jm8Fv2YP8D+vAx+8O33spk8BX34x/bozz19vKMSzhYsHT45h9+Ez+NWFS7ilqwNbVi3EuiXdzuIhj+WVfcesvlIr3tCf6T5PhzYDQzOG/3Vy+nFcZ3Voc20nBUSPD22Onj/7GqZDIXnvU2Xo4MkxbHv+FC5NRPGPXbiEbc+fAgB2VjRXXtl3zOorteLV+jMdIkjK4rN13DQez4Y4dh8+c62TmnJpYhK7D59xEg95Lq/sO2b1lVrxOirTIYKkxAhbxwPfp+pXFy4ZHaeSyyv7jll9pVa8oT/TIYKkrDypzD3WzPNN4/FsiOOWrg6MxXRKt3R1xDybCPll3zGrD0A554yL943KdIggKYuvmeMF3Kdqy6qF6Giv7YQ72ivYsmqhk3iIymxqznjswiUopueMD54ccx1arorXUZkOEazZA/Rvql2g2L8pOeuv7565356mHhdwn6p1S7rx+IOL0N3VAQHQ3dWBxx9cVPhPcEQ+Kuucsahqyy/a39+vQ0NDLb+uFU/cFT80l7iwtzcq40SFIiLDqtrvOo4pQbcpatiCrS8h7h1bALy/609bHY51Se2qeN+o8maalMF1HkRkSdLccNHnjNlRmUpLyjB5PhGRobLOGYeR9WdaqSHp+TYqPqzcDnz/q8Dk5eljleuAJQ/XbucBBLfOo4zZRLMV7jXwrMoJZTN1L6bdo4W7hxFCR2VaqSHp+Tb3hZo9r6caJVn03RPsmwIrUBTwNfCsygnZsW5Jd+L9WLh7uMr/oT9blR1s7Qt1ZCdwdaL22NWJ6Hie+2DlrKzZRDMV7jXwrMoJ5a9w93CV/x2VrQoOtpIdPKscYQsrUBTwNSjovUrJCncPV/nfUaVVcDA5bivZwfS6gShrNtFMhXsNCnqvUrLC3cNVtvajWi0iZ0TkPRHZauOc1zRT2aFyXe2xynXp+0Il7Rz6nQeAHZ3Tf77zgHeVI2wJPZvo4MkxLN/1KhZsfQnLd73a1Er90F+DOQp6r1Kywt3DVZmTKUSkAuBJAJ8BMArghIi8qKrvZD03gOl5HpMkBZNkByB+wvnoP83dzuP9H0c/1+4NNmkiSSPZRL6yNYHs02sgIqsBfBtABcC/quou45M003YoaHnew81mE9rIQsxcmUJE7gWwQ1VXVR9vAwBVfTzpd3JdRZ9UOSKpQkTS89PsuNhcbJSL5btejS2c293Vgde3rsjlmnlWpqh++PsFZnz4A7Ah7cMfK1NQnmZ/GASib2r1yqmZ/l6elSm6Acx8px+tHnMjkO0zyJ4CTiDfDeA9Vf2lql4G8F0An3UcE5VYs9mEtrIQW5ZMISIDIjIkIkPj4+P5XchW8gUFo4ATyH59+KPSa/bDoK0PkTYW/I4B6J3xuKd6rIaq7gOwD4iGKWLPZKOixMrttXNOQP3kixe+Upu+LhXgE3fMnaMCgAX3ebfa/9GDp3Dg+H
lMqqIigg3LevHYukXGY8N5P9+W2df9o0/Nw/d+dh4TV6dvq/Y2CX4CuR4RGQAwAAB9fX2Oo6Eia3ZfOlv72dnoqE4AuENEFiDqoL4E4CHjs9iqKGE6gXzu2Nw1VjoJfPxm4IMzQE2tYgE+cbtXq/0fPXgKzxw7d+3xpCqeOXYO749/hDfOXWw4wcA0IcHVCvi4637vRNRJ15DcQmgFex/+iCzYsmph7FxTvQ+Dzf7ebFa2+RCR+wH8M6IMpadV9R/Snh878etq+4xv3Zi8GDiOZ9t53Lbt5blv0imSEgxMExJcJDCkXTdOwMkUH0OUTLESUQd1AsBDqvp20u+UIZmiiDXsQtKKrL+kdmWl1p+qvgzg5UwncbV9hkknlfZ8R0kZJp0UYD5mbOu4LSbnDzWZQlWviMjXABzG9Ie/xE6qDIpawy4kaTUG8/i9mfypTOFq+4yk85s+31FSRkXMxrhMEw9sHbfF5PwBJ1NAVV9W1T9Q1dvqjVCUQVFr2FFj/OmoklbRL90ItLXXHm9rt7e6funG+OML7kuOx6PV/huW9cYeX37bjWhvq+3E0hIMTFe0b1m10Oj8aUyqSmxZtRDtldrrVtpkTixFWI1P0wq4BIEM+NNRLV4fVXzo7AUg0c+1e6NqErO/NRh+i0i1Zg/Qv2n6m5JUosdffjE+njV74o87yvrrv/VGVNrmvnEvmHfD3ISClJdt3ZJuPP7gInR3dUAQze/UW8xncv4kU0M6YxcuQTE9pJNaAmnWaGcbgC/e3WsWOwWlgEsQyICVZApTRhO/ppUmSiYpuaAiEjt/ZSvBwFYyRShJHLPlmUzRjKInUzRbGYHCkmsyRa5YUSJV0tBHUpKFraESW0MxoSRxkFs+1WFsVJYsRWY41vK/o+rsSfhGxYoSQPKCuqRvVLaGSmwt5DM9j63rUnhsZI+1SpYsRWY4zuXPHFUSblWQKikJYsOy3ti5K1sJBraSKdKSOOKSLEyTPmxs/0FkKkuWIjMc5/K/o0pKsuBWBQCSkyAAYPJq7TeqyauKobMf2ru4hWSKtPjjkiwANJz00VSiBpEFWYaoObw9l/9Df0DUKbFjShQ3JPLI4M9jn3vg+Hk8tm5R5mvuPnwGE5O1HeHEpGL34TPGwxNx8S/f9Wrip8rXt65o6Bppn0zLOoRCrZFliJrD23P5/42KmpKUTGFaySJJ3p/6bJyfn0zJlSw77RZ1l94s2FEVVFLFCtNKFknyXtdi4/xce0OuNLUu0cLvFlUYQ39kbMOy3pqq6jOP22CrKnKe5887RqI09bIU01LQQ8pwbAV2VC2Q95qIuPNPzUPlvU9VXv9dNs4f4tobKgemoJvxvzJF4PJeUW96/ryfXxasTEFZ+FJhxTdJ7SrTHJWIfEFE3haRqyLiTaP1Sd5rIkzPn/fziag+JvqYyZpM8RaABwEctRBLIfmWHceSRUTuMdHHTKaOSlVPqyo/WqfwLTsulH2niIqMKehmmJ6eM5s3pGlJIRsliNigqOzyKMPFFHQzdZMpROQ/Afx+zD99U1W/X33OfwH4a1VNnM0VkQEAAwDQ19e39OzZs83GHBwbWX9pSQ3A3Mw2AEbPbzbrr6yYTFEOTCZqraR2ZSXrr5GOaiY2KnOh7ttUVOyoyoHtqLVyyfqj1mESRDkwk9YvbEd+yJqe/jkRGQVwL4CXROSwnbBoNiZBlAYzaT3CduSHrFl/L6hqj6per6o3q+oqW4GFKM+9j7asWoj2yqz9nyrJ+z/5mATBvaHqYyatX3xsR2XEoT9LWrL30ezpxJTpRd+yirg3lH0iMiAiQyIyND4+7jqcQlq3pBufX9p9rZhzRQSfX8o6fK3GWn+W5L330e7DZzAxayPEiavp+z/5VNiSe0NNaySTthGqug/APiBKprAUHs1w8OQYnhseu7Y9zqQqnhseQ/+tN5buvnWJHZUlvlWg8E3o8dukqn/sOgZqDD9g+YFDf5b4VoHCN6HHT+XED1h+YEeVwm
TyP+9J19AndUOPv1WYSesXfsDyAzuqBKaT/3knL/iWHGEq9PhbhZm0fuEHLD9wP6oEXJFOaViZIixZyoCxhFjrJLUrJlMk4Ng0UTFk3U3Xp+zZsuLQXwKOTRMVAzf/DF/Y36hGBoEjO4GLo0BnD7ByO7B4vZVTb1m1MLZqsuvKDiEPQYQeP4WJoyPhC7ejGhkEfvB1YKJ6s108Hz0GrHRWU2+gvryxZh2+cC30+Clct3R1xM43c3QkHOF2VEd2TndSUyYuRcctfavyaWw69IWHocdP4fJxdITMhNtRXRw1Ox640IcvQo+f/PfowVM4cPw8JlVREcGGZb14bN0i70ZHyFy4HVVnTzTcF3e8gEIfvgg9fvLbowdP4Zlj5649nlS99niqs2LHFK6s+1HtFpF3RWRERF4QkS5bgdW1cjvQPutNrr0jOl5AIS08jKvoEVL8FJ4Dx2M+tKYcp7BkTU9/BcBdqroYwC8AbMseUoMWrwfW7gU6ewFI9HPtXmvzU74JpbJDUkUPAEHET2GaTChckHScwpJp6E9VfzTj4TEAf5YtHEOL1xe2Y4oTwvBFWtLE61tXeB8/hakiEtspTe0jRWGzueD3LwD80OL5KEBMmiAXNizrNTpOYan7jaqRTd5E5JsArgB4NuU8AwAGAKCvr6+pYMl/TJooB98Wbz+2bhEAxGb9UfjqdlT1NnkTkY0A1gBYqSkVbrkbaTlwzUrx+bp4+7F1i9gxFVTWrL/VAP4GwAOq+r92QqKQhZL0Qc1j7TxqtazrqP4FwPUAXpFo0vKYqn4lc1QUtBCSPqh5nIekVsua9Xe7rUCIKAych6RW4zYfRGSEi7ep1cItoURUQCKyG8BaAJcB/DeAP1fVC26jqsXaedRq7Kia4FtqLhXKKwC2qeoVEflHRNVe/tZxTHNwHrI+vk/Yw47KkK+puVQMzqu9kBV8n7CLc1SGmJpLLcRqL4Hi+4Rd/EZliKm5lBWrvRQf3yfsYkdliKm5lBWrvRQf3yfs4tCfIabmUp5Y7aUY+D5hF79RGWJqLuWM1V4KgO8TdrGjagJTcykvrPZSHHyfsIdDf0RE5DVJmavN76Ii4wDOtvzC8W4C8IHrIBL4HBvgd3x5x3arqs7L8fxGmmhTPv+/awTjdyuv+GPblZOOyiciMqSq/a7jiONzbIDf8fkcmw9Cf30Yv1utjp9Df0RE5DV2VERE5DV2VNUFk57yOTbA7/h8js0Hob8+jN+tlsZf+jkqIiLyG79RERGR19hRARCRL4jI2yJyVUS8yMQRkdUickZE3hORra7jmUlEnhaR34jIW65jmUlEekXkNRF5p/r/8y9dx+QzH+/7RvjcNurxte00ylUbY0cVeQvAgwCOug4EAESkAuBJAH8C4E4AG0TkTrdR1dgPYLXrIGJcAfCIqt4J4B4AX/XsdfONV/d9IwJoG/Xsh59tp1FO2hg7KgCqelpVfdoo5m4A76nqL1X1MoDvAvis45iuUdWjAD50HcdsqvprVX2j+vffATgNgDVsEnh43zfC67ZRj69tp1Gu2hg7Kj91Azg/4/Eo+IZrRETmA1gC4LjbSMgytg1PtLKNlaYobSOb1VExiMgNAJ4D8A1V/a3reFzifU95aHUbK01HVW+zOs+MAeid8bineozqEJF2RA3oWVV93nU8rgV23zeCbcMxF22MQ39+OgHgDhFZICLXAfgSgBcdx+Q9iTZwegrAaVXd4zoeygXbhkOu2hg7KgAi8jkRGQVwL4CXROSwy3hU9QqArwE4jGiyclBV33YZ00wicgDATwEsFJFREdnkOqaq5QAeBrBCRN6s/rnfdVC+8u2+b4TvbaMej9tOo5y0MVamICIir/EbFREReY0dFREReY0dFREReY0dFREReY0dFREReY0dFREReY0dFREReY0dFRERee3/AU5WwQ/TxKn1AAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 504x180 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qZis59tPAsL1",
        "colab_type": "text"
      },
      "source": [
        "## Low-level implementation with manual gradients"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9XKUolctAsL2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "def custom_where(cond, x_1, x_2):\n",
        "    \"\"\"Element-wise `x_1 if cond else x_2`, built from arithmetic.\n",
        "\n",
        "    `cond` is a boolean tensor; multiplying it by 1 turns it into an integer\n",
        "    0/1 mask that selects between the two branches.\n",
        "    \"\"\"\n",
        "    cond = cond*1 # Bool to int\n",
        "    return (cond * x_1) + ((1-cond) * x_2)\n",
        "\n",
        "\n",
        "class LogisticRegression1():\n",
        "    \"\"\"Binary logistic regression trained with hand-written gradient updates.\n",
        "\n",
        "    `weights` has shape (num_features, 1) and `bias` has shape (1,); both start\n",
        "    at zero and are allocated on the module-level `device`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_features):\n",
        "        self.num_features = num_features\n",
        "        self.weights = torch.zeros(num_features, 1,\n",
        "                                   dtype=torch.float32, device=device)\n",
        "        self.bias = torch.zeros(1, dtype=torch.float32, device=device)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return sigmoid(x @ weights + bias), i.e. P(y=1 | x), with shape (n, 1).\"\"\"\n",
        "        linear = torch.add(torch.mm(x, self.weights), self.bias)\n",
        "        probas = self._sigmoid(linear)\n",
        "        return probas\n",
        "\n",
        "    def backward(self, probas, y):\n",
        "        \"\"\"Return the residuals `y - probas`, flattened to shape (n,).\n",
        "\n",
        "        For the cross-entropy loss these residuals are the negative gradient\n",
        "        w.r.t. the linear output, which is how `train` uses them.\n",
        "        \"\"\"\n",
        "        errors = y - probas.view(-1)\n",
        "        return errors\n",
        "\n",
        "    def predict_labels(self, x):\n",
        "        \"\"\"Return integer labels: 1 where P(y=1 | x) >= 0.5, else 0.\"\"\"\n",
        "        probas = self.forward(x)\n",
        "        labels = custom_where(probas >= .5, 1, 0)\n",
        "        return labels\n",
        "\n",
        "    def evaluate(self, x, y):\n",
        "        \"\"\"Return the fraction of examples whose predicted label equals `y`.\"\"\"\n",
        "        labels = self.predict_labels(x).float()\n",
        "        # Summing a bool tensor yields an integer tensor. Cast to float before\n",
        "        # dividing: older PyTorch releases floor-divide `int / int` tensors,\n",
        "        # which rounds any accuracy below 1.0 down to 0.\n",
        "        num_correct = torch.sum(labels.view(-1) == y).float()\n",
        "        accuracy = num_correct / y.size(0)\n",
        "        return accuracy\n",
        "\n",
        "    def _sigmoid(self, z):\n",
        "        return 1. / (1. + torch.exp(-z))\n",
        "\n",
        "    def _logit_cost(self, y, proba):\n",
        "        \"\"\"Summed cross-entropy -sum(y*log(p) + (1-y)*log(1-p)) as a (1, 1) tensor.\"\"\"\n",
        "        tmp1 = torch.mm(-y.view(1, -1), torch.log(proba))\n",
        "        tmp2 = torch.mm((1 - y).view(1, -1), torch.log(1 - proba))\n",
        "        return tmp1 - tmp2\n",
        "\n",
        "    def train(self, x, y, num_epochs, learning_rate=0.01):\n",
        "        \"\"\"Run `num_epochs` full-batch gradient-descent steps, logging each epoch.\"\"\"\n",
        "        for e in range(num_epochs):\n",
        "\n",
        "            #### Compute outputs ####\n",
        "            probas = self.forward(x)\n",
        "\n",
        "            #### Compute gradients ####\n",
        "            errors = self.backward(probas, y)\n",
        "            # X^T (y - p): negative gradient of the summed loss w.r.t. the weights.\n",
        "            neg_grad = torch.mm(x.transpose(0, 1), errors.view(-1, 1))\n",
        "\n",
        "            #### Update weights ####\n",
        "            # Adding the negative gradient is a gradient-descent step.\n",
        "            self.weights += learning_rate * neg_grad\n",
        "            self.bias += learning_rate * torch.sum(errors)\n",
        "\n",
        "            #### Logging ####\n",
        "            print('Epoch: %03d' % (e+1), end=\"\")\n",
        "            print(' | Train ACC: %.3f' % self.evaluate(x, y), end=\"\")\n",
        "            print(' | Cost: %.3f' % self._logit_cost(y, self.forward(x)))"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "mlnt9D55AsL7",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 314
        },
        "outputId": "44bc8d4c-dafe-4124-f146-8f345f622c4c"
      },
      "source": [
        "# Convert the standardized NumPy splits to float32 tensors on `device`.\n",
        "X_train_tensor, y_train_tensor = (\n",
        "    torch.tensor(arr, dtype=torch.float32, device=device)\n",
        "    for arr in (X_train, y_train)\n",
        ")\n",
        "\n",
        "# 10 epochs of full-batch gradient descent with step size 0.1.\n",
        "logr = LogisticRegression1(num_features=X_train_tensor.shape[1])\n",
        "logr.train(X_train_tensor, y_train_tensor,\n",
        "           num_epochs=10, learning_rate=0.1)\n",
        "\n",
        "print('\\nModel parameters:')\n",
        "print('  Weights: %s' % logr.weights)\n",
        "print('  Bias: %s' % logr.bias)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001 | Train ACC: 0.000 | Cost: 5.581\n",
            "Epoch: 002 | Train ACC: 0.000 | Cost: 4.882\n",
            "Epoch: 003 | Train ACC: 1.000 | Cost: 4.381\n",
            "Epoch: 004 | Train ACC: 1.000 | Cost: 3.998\n",
            "Epoch: 005 | Train ACC: 1.000 | Cost: 3.693\n",
            "Epoch: 006 | Train ACC: 1.000 | Cost: 3.443\n",
            "Epoch: 007 | Train ACC: 1.000 | Cost: 3.232\n",
            "Epoch: 008 | Train ACC: 1.000 | Cost: 3.052\n",
            "Epoch: 009 | Train ACC: 1.000 | Cost: 2.896\n",
            "Epoch: 010 | Train ACC: 1.000 | Cost: 2.758\n",
            "\n",
            "Model parameters:\n",
            "  Weights: tensor([[ 4.2267],\n",
            "        [-2.9613]])\n",
            "  Bias: tensor([0.0994])\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/pytorch/aten/src/ATen/native/BinaryOps.cpp:81: UserWarning: Integer division of tensors using div or / is deprecated, and in a future release div will perform true division as in Python 3. Use true_divide or floor_divide (// in Python) instead.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UFdLvASiAsMA",
        "colab_type": "text"
      },
      "source": [
        "### Evaluating the Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "d0TEqGPYAsMB",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "8ccb519e-3f82-4a90-c466-0ca6bdb551f9"
      },
      "source": [
        "X_test_tensor, y_test_tensor = (\n",
        "    torch.tensor(arr, dtype=torch.float32, device=device)\n",
        "    for arr in (X_test, y_test)\n",
        ")\n",
        "\n",
        "# `evaluate` returns the fraction of correct predictions; report it in percent.\n",
        "test_acc = logr.evaluate(X_test_tensor, y_test_tensor)\n",
        "test_acc_pct = test_acc * 100\n",
        "print('Test set accuracy: %.2f%%' % test_acc_pct)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test set accuracy: 100.00%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p7VHf3mOAsML",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 211
        },
        "outputId": "b06ebf90-d36f-4c32-847f-1d8a76e8ce45"
      },
      "source": [
        "##########################\n",
        "### 2D Decision Boundary\n",
        "##########################\n",
        "\n",
        "# The boundary is the line w0*x0 + w1*x1 + b = 0, i.e. x1 = (-w0*x0 - b) / w1.\n",
        "# Copy the parameters to host memory first: matplotlib converts its inputs to\n",
        "# NumPy arrays, which fails for tensors that live on a CUDA device.\n",
        "w, b = logr.weights.view(-1).cpu().numpy(), logr.bias.cpu().numpy()\n",
        "\n",
        "x_min = -2\n",
        "y_min = (-(w[0] * x_min) - b[0]) / w[1]\n",
        "\n",
        "x_max = 2\n",
        "y_max = (-(w[0] * x_max) - b[0]) / w[1]\n",
        "\n",
        "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n",
        "\n",
        "ax[0].plot([x_min, x_max], [y_min, y_max])\n",
        "ax[1].plot([x_min, x_max], [y_min, y_max])\n",
        "\n",
        "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n",
        "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n",
        "\n",
        "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n",
        "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n",
        "\n",
        "ax[0].set_title('Training set')\n",
        "ax[1].set_title('Test set')\n",
        "ax[1].legend(loc='upper left')\n",
        "plt.show()"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaoAAADCCAYAAAAYX4Z1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deVyVdfr/8dcnxMANc1dQccsVFUGhrGkfraY028R9w6Zpqmn62tjor9XKxmZqmmwccFfUNM32sWxzMkHAHfcdcEfBDQQOn98fgKEelsO5z7nvc871fDx8POJwuLk0Lt7nvu/rfD5Ka40QQghhVdeZXYAQQghREQkqIYQQliZBJYQQwtIkqIQQQliaBJUQQghLk6ASQghhaTXM+KaNGjXSoaGhZnxrIQyRmpp6Smvd2Ow6SklPCW9QXl+ZElShoaGkpKSY8a2FMIRS6pALjx0ArAGup7hHP9Zav1zR10hPCW9QXl85HVTVaSohRIUuAXdqrc8rpfyBn5VSX2utE80uTAgzGHFGJU0lhIF08XIx50s+9C/5I0vICJ/ldFBJU3mWlRszmbZqF0eyc2lRP5AJ/ToyMDzY7LLEVZRSfkAq0B6YrrVOsvOc8cB4gFatWrm3QCHcyJB7VFVpKmG+lRszeXHFVnILbABkZufy4oqtABJWFqO1tgE9lVL1gU+UUt201tuuek4cEAcQGRl5zYvDgoICMjIyyMvLc0vNniIgIICQkBD8/f3NLkVUkSFBVZWmkld/5pu2atflkCqVW2Bj2qpdElQWpbXOVkr9APQHtlX2/LIyMjKoW7cuoaGhKKVcU6CH0VqTlZVFRkYGbdq0MbscjzLjp33Uub4Gw6Jbu/17G/o+Kq11NlDaVFd/Lk5rHam1jmzc2DJTvT7lSHauQ497ssNZF8nJLTC7jGpRSjUuedGHUioQuAfY6ehx8vLyaNiwoYRUGUopGjZsKGeZDopfs5+pX+8k+eBpzNhxw+mgMqqphOu1qB/o0OOe6lDWBR6PW8ezSzaaXUp1NQd+UEptAZKBb7XWX1TnQBJS15J/E8fM/vkAb3y1g/vDmvP3R3uY8u9nxBmVYU0lXGtCv44E+vtd8Vigvx8T+nU0qSLjHc66SExcInkFNv7Sv5PZ5VSL1nqL1jpca91da91Na/2a2TUZ6ZVXXuGdd95xybFTU1MJCwujffv2PPPMM6a8+vcm89cd5LUvttO/azPeG9yTGn7mLGbk9Hf19qbyJgPDg3lrUBjB9QNRQHD9QN4aFOY196cOZ11kcNw6cgtsJIyLpnPzemaXJNzsySefJD4+nj179rBnzx7++9//ml2Sx0pIOsRLn6ZxT5emvB8Tjr9JIQUmrUwhzDMwPNhrgqms9NMXiYlP5GKBjYRxUXRpISHlKFe8dWH+/Pm88847KKXo3r07CxYsuOLz8fHxxMXFkZ+fT/v27VmwYAG1atVi2bJlvPrqq/j5+REUFMSaNWtIS0tj9OjR5OfnU1RUxPLly+nQocPlYx09epSzZ88SHR0NwIgRI1i5ciX33nuvU38HX/RR8mEmfbKNuzo1YfqQXtSsYe6ysBJUwuOln77I4LhEzl8qJGFcFF1bBJldksdxxVsX0tLSmDJlCr/88guNGjXi9OnT1zxn0KBBxMbGAjB58mRmzZrF008/zWuvvcaqVasIDg4mOzsbgBkzZvDss88ydOhQ8vPzsdmunGDNzMwkJCTk8schISFkZmZWq3ZftiwlnYkrtnJ7x8Z8OMz8kAJZPV14uIwzxWdSpSHVLVhCqjoqeutCdX3//fc8+uijNGrUCIAGDRpc85xt27Zx6623EhYWRkJCAmlpaQD07duXUaNGER8ffzmQbrrpJt58803efvttDh06RGCgdw0BWcEnGzN4YfkWbmnfiBnDIri+hl/lX+QGElTCY2Vm5zI4LpGzuQUSUk4y660Lo0aN4o
MPPmDr1q28/PLLl8fGZ8yYwZQpU0hPTyciIoKsrCyGDBnCZ599RmBgIPfddx/ff//9FccKDg4mIyPj8scZGRkEB1fvbHDlxkz6Tv2eNhO/pO/U71m50fvPzD7dlMnzSzdzU9uGxI+IJMDfGiEFElTCQxWH1LqSkIqWkHKSK966cOedd7Js2TKysrIA7F76O3fuHM2bN6egoICEhITLj+/bt4+oqChee+01GjduTHp6Ovv376dt27Y888wzDBgwgC1btlxxrObNm1OvXj0SExPRWjN//nwGDBjgcN2ll0Ezs3PR/HoZ1JvD6ostR3juo030Dm3AzJHWCimQoBIe6Eh2LjFxiWRfLGDhuCjCQiSknOWKty507dqVSZMmcdttt9GjRw/+/Oc/X/Oc119/naioKPr27UunTr++nWDChAmEhYXRrVs3br75Znr06MHSpUvp1q0bPXv2ZNu2bYwYMeKa43344YeMGzeO9u3b065du2oNUrjiMqiVfb31KM8u2URE6xuYPao3tWpab3RBmfE+g8jISC1754jqOJpTfLnv9Pl8Fo6LokfL+qbUoZRK1VpHmvLN7bDXUzt27KBz585VPoYvLVhc0b9Nm4lf2l1VWwEHpt7v0rrc7Zu0Y/whYQPdQ4KYPzaKOtebG1Ll9ZX1olOIchzLybscUgtMDClv5a1vXXBUi/qBZNq5N+dtK7h8t+M4Ty3aQLfgIOaN6WN6SFVELv0Jj1AcUuvIOp/P/LF96CkhJVzEF1Zw+WHXCZ5cuIHOzesxb0wf6gZYeyV560aoECWOn80jJj6RUyUhFd7qBrNLEhZ25mI+x3PyyLcVUdPvOpoGBXBDrZpV/vrSs0pvvQy6ZvdJnliQSoemdVgwJoqgQGuHFEhQCYs7fjaPmLhETpzNY/7YKHpJSIkKnLmYT+aZXIpK7r3n24rIPFN8Gc/RsPKWYCpr7d5TxM5PoV3jOiwcG0VQLeuHFMilP2FhJ0rOpI6fzWP+2D5EtJaQEhU7npN3OaRKFWnN8RzZ1mPdvizGzkumTaPaJIyL4obaVQ9us8kZlbCkE+eKQ+pYTh7zx/QhovW1qxoIcbV8W5FDj/uK9QdOM2ZuMi1vqMXCcVE08KCQAjmjEhZ08twlYuISOZqTx7wxfYgMlZDyJq7c5uPDaVP4bZ+uRHcMueLxmiau/G221EOnGT1nPS3qB7AoNppGda43uySH+e7/PWFJJ89dIia+OKTmju5Dbwkp4YBBDw1g8RdXLq10nVI0DQowqSJzbTx8hpGzk2laL4DFsdE0rut5IQVy6U9YyKnzlxgSn0jmmVzmju5NnzYSUm7zZjDkn7/28Zp14K/VXzrIndt8APz29ls5czGf0j1oqzP15y02p2czYtZ6GtapyaLYaJrU89ywlqASgPmrEpSGVMaZXOaM7k1U24Zu+95Wo5RqCcwHmgIaiNNa/9Ol39ReSFX0eBW4e5uPUjfUqolS0D3Ed99rty0zh+Gzkqhf25/FsdE08/AzSqcv/SmlWiqlflBKbVdKpSmlnjWiMOE+Zi/CmXX+EkPjkzh8+iKzRkUS7cMhVaIQeF5r3QWIBp5SSnUxuSaHyTYf5kg7ksPQmUnUDSgOKW9YUcOIe1Re0VS+zMxFOLPOX2LozCQOnb7A7JG9ubldI5d/T6vTWh/VWm8o+e9zwA7A+97Ug7HbfAjYeewsw2YmUbumH4tjowm5oZbZJRnC6aDypabyVmbtRXT6Qj5DZyZx4NQFZo3szc3tJaSuppQKBcKBJDufG6+USlFKpZw8edLdpVXK3dt8+Lrdx88xND6J62v4sSg2mlYNvSOkwOCpP09uKl/mir2IKnPmqpDqKyF1DaVUHWA58Cet9dmrP6+1jtNaR2qtIxs3buz+AithxjYfL7zwAiEhIVy8eJGQkBBeeeUVV/4VLWPviXMMiU/E7zrFotgoQhvVNrskQxm2zUdJU/0EvKG1XlHRc2WbD2spvUdV9vJfoL8fbw0Kc8
lARWlI7T15nlkjI7m1g/V+yVbG1dt8KKX8gS+AVVrrf1T2fKe3+XDR1J9VOboFipXtO3mewXGJaA1LxkfTvkkds0uqNpdu81HSVMuBhMpCSlTf5JVbWZyUjk1r/JQiJqolUwaGOX1cdy7CmX0xn2GzikNq5gjPDClXU0opYBawoyohZQgvDCNfcPDUBYbEJ1JUpD0+pCridFCZ0lQ+aPLKrSxMPHz5Y5vWlz82KqxcPY5eGlJ7TpwnfkQkv7lRQqocfYHhwFal1KaSx/6qtf7KxJqExRzOukhMfCIFNs3i2Gg6NK1rdkkuY8Q9qtKmulMptankz30GHFeUsTgp3aHHrSbnYgHDZ61n97HzxA2P4DYJqXJprX/WWiutdXetdc+SPxJS4rL008UhlVtgY+HYKDo2896QAgPOqLTWP8PlN4ILF7GVcy+xvMetJCe3gOGzk9h17Bz/GR7B7R2bmF2Sz9BaU3zRQ5Qy6r68WTKzc4mJT+RcXgGLYqPp0qKe2SW5nKz15yH8yvllU97jVpGTW8CIWUnsPHqOGcN7cUcnCSl3CQgIICsry+N/MRtJa01WVhYBAZ65UsPRnFxi4hLJyS1g4bgougUHmV2SW8gSSh4iJqrlFfeoyj5uVWfzChgxez3bj55lxrAI7uzU1OySfEpISAgZGRl489tBsi/mc+GSDU3xZZ3a1/tRv5J1/QICAggJCanwOVZ0/GweQ+KTOHMhnwXjonxqiSgJKjdyZj290oEJZ6f+3LWm39m8AkbMWs/2Izn8e2gEd3WWkHI3f39/2rRpY3YZLlM8YJRxzePDolsZMmBkJaX7s5XudN2zpe+EFEhQuc3V71UqXU8PcCisnGlAI2qoinN5BYycvZ60Izl8ODSCu7tISAnjVTRg5E1BVbxgcxLHSvZn88WdruUelZuYuZ6eO2soDamtGTlMH9KLeySkhIt48oBRVZUu2Jx5Jpc5o3r77P5sElRuYtZ6eu6s4fylQkbNSWZLRg4fDOnFb7s2M+S4QtjjqQNGVVW6gsvBrAvMGhnp01vfSFC5iRnr6bmzhvOXChk1ez2b07P5YEg4/btJSAnXKm+QyJEBo5UbM+k79XvaTPySvlO/d9vWNpXJvlgcUvtPXWDmyEifX7BZgspNJvTriP91V77S879OMaFfR7vPd0UDTejXkUB/vyseC/T3K7eGqjp/qZDRc9azMT2bf8WE079bc6eOJ0RVTBkYxrDoVpfPoPyUcmiQwux92MqTk1v85vi9J4rfHC/LjMkwhXtdfUWinCsUrhp6cMWafhcuFTJmTjIbDheH1L1hElLCfZwZMKronq07d7cuq/QtHTuPnZU3x5chQeUm01btosB25U3eApu22xSubCAj1/S7cKmQ0XOTST18hvcHh3OfhJTwIFa4b1zW5WnZzBz+Le87vIJc+nMTR5rCag1kz8X8QsbMTSb10Bnee7wn93eXkBKexQr3jUtduFTI6DKDSDIteyUJKjdxpCms1ED2lIZU8sHTvPt4Tx7o0cLskoRwmKvu2TrqYn7xlYmN6dm8P1gGkeyRoHITR5rCKg1kT26+jbFzU1h/oDikHpSQEh5qYHgwbw0KI7h+IAoIrh/oss1Cy1PaTyklL/rkyoR9Pn2PylXLCVV03Kp8P3duZOiI3HwbY+clk3Qgi3cf78mAnubWI4Sz3LEPW3nyCmzEzk8h8UAW7z4mL/oq4rNB5arJusqOW9Vjm9lA9uQV2Bg3P5l1+7P4x2M9JKSEcEJegY3xC1JZu+8U0x7pYaletyKfvfTnquWErLBUktHyCmyMm5fCL/uy+PujPXgo3PNWnhbCKi4V2vhDwgbW7D7J24O680iE9FNlfDaoXDVZ5wkTe44ovTxR+spvUC9pKndQSs1WSp1QSm0zuxZhnPzCIp5K2Mj3O0/w5kNhPNbbutv0WIkhQeWJTeWqyTqrT+w5ovTyxM97T/G3h+WVn5vNBfqbXYQwToGtiKcXb2D1juO8Pr
AbQ6JamV2SxzDqjGouHtZUrpqss/LEniPyCmw8sSCV/+05ydsPd+fRSHnl505a6zXAabPrsDpXrdVn9HELbUU8u2Qjq9KO88oDXRge3dqQOn2FIcMUWus1SqlQI47lLq6arBsYHsyylMOs3ffr75herYIYGB5sdxrQFTU4K6/Axu8XpvLT7pP87eHuPCYhZUlKqfHAeIBWrXzv1blZA1GOKrQV8dzSzXy19RiT7+/MqL7eu5mlqyht0N4tJUH1hda6WzmfL9tUEYcOHTLk+1pN8a6j124Z37ddAzYczrli0MLfT4GGgqJf/x8E+vu5/b0cZV0qtPH7Ban8sOskUweFMbiP7/0CrAqlVKrWOtLF3yOUCnqqrMjISJ2SkuLKciyn79TvybRz7ze4fiBrJ95piePaijTPL93Eyk1HePHeTjxxW7tq1+ULyusrtw1TaK3jtNaRWuvIxo29dzXg8nYdXbvv9DXTgAU2fUVIgbkTgpcKbTy5cAM/7DrJWxJSwuKsPhBlK9JM+HgzKzcdYUK/jhJSTvDZqT9XMWJ3UTMmBC8V2vjDwg2Xp5FiJKSExVl5IKqoSPPiii2s2JDJn++5kafuaO9UTb5OgspgRuwu6u4JweKR2Q18t/MEU2QayRKUUouBdUBHpVSGUmqs2TVZjVUHooqKNJNWbmVpSgbP3NWBZ+7q4FQ9wqBhipKmuh1opJTKAF7WWs8y4tieJiaqpdP3qNw5IZhfWMRTizawescJXh/YjWEyjWQJWusYs2uwuoHhwaQcOs3ipHRsWuOnFA9HOLaii7PLnV1Na81Ln21j8fp0nrqjHc/dLSFlBKOm/jyyqRxZ62/yyq1XNERMVEu7G7ZNGRhG0v4s9py4cPmxDk1qkxB7k93v50ijGb02YYGtiD8u2sC324/z2oCuVRuZfTMY8s9f+3jNOvBXa2zjLXzDyo2ZLE/NvHy53aY1y1MziWzdoEp9YdRyZ6W01rz6+XYWJh7midva8n+/7Ygy4AqL8OFLf45sQ106yVe2IRYmHmbyyq12n1s2pAD2nLjA5JVbGRgezNqJd3Jg6v2Xp4fsNZq9GozeNrs0pL7ZfpxXH+zKiJtCq/aF9kKqoseFcBFnlyszcrkzrTWvf7GDub8cZNwtbZjYv5OElIF8Nqgc+SEtb5LP3uOOPNeRGoxsqgJbEc8s/vXNhyNvDnX4GEKYzdnpPKOm+7TWvPX1TmavPcCom0OZdH9nCSmD+WxQOfJDWt4kn73HHXmuGbv+FpS8Q/7rbcd46Xdd5M2HwmM5O51nxHSf1pppq3YRt2Y/w6Nb8/IDXSSkXMBng8qRH9LyJvnsPe7Ic92962+hrYg/LdnEV1uP8f9+14Uxt0hICc/l7HSeEVOD767ew4c/7mNIVCtefbCrhJSL+GxQOfJDGhNlfwkhe4878lx37vpbaCvi2Y828eXWo0y+vzNjJaSEh3N2h15nv/6fq/fw/nd7eDyyJVMGdOO660wMqTeD4ZWga/+86R37XPnsxomOjKCWTvdVdervwMnzV6z117ddA6YMDLM7tffWoDCX7/pbutbYl1uOMum+zoy7tW3V/pHsqVmn/Kk/IdzM2Q1Gq/v103/Yy7urd/NIRAhvDQozN6TA64ecDFvrzxHevC7Z1SOvUHzm83BEMMtTM6953NXr+hXaivjz0s18tvkIf72vE+N/I8u4GMEda/05wpt7ympm/LSPqV/v5KHwYN55tAd+ZocUFJ89lfu5HPfV4STT1/rzFeVN5y1OSnf7zr+2Is3zy4pDauK9ElJCOGvm//Yz9eudPNCjBdMe6W6NkPIBElQGK28Kr7xpQFet61e6avOnm47wl/6d+L0siCmEU+asPcCUL3dwf1hz3n2sBzX85Nenu/jsPSpXaVE/0O4WAX5K2Q0rV6zrZyvSTFj266rNT94uISXMY/SKKmaYv+4gr36+nX5dm/Le4J6eFVJXXxb0wFVkvC6oHGkKV2xkOKFfRyZ8vJkC26
+h5O+neLx3S7v3qIxe1690a4EVGzOZ0K9j9Vdt9sWlknzx7+xirtrc0J0Skg7x0qdp3N25Kf+K6YW/FUOqvCEnezxwwMKrgsqRprD33AnLNoPicshUu6muPnHSENm6AZGtG7j0laWtSPPCx8VbCzzv7NYCXj5FZJcv/p1drKIVVTwhqD5KPsykT7ZxZ6cmTB8aTs0aFgwpsP9CqqIBCw/jVUHlSFPYe+7VmxhW9PUV1XD1cQqKit+9vnbinS5rzqIizcTlW1i+IYPn7r6Rp2VrAWEBrtrc0B2WpaQzccVWbruxMR8O7cX1Nfwq/yLhEhZ9eVA9RixJ5Mhxna3BKEVFmokrtrAsNYM/3d2BZ2VrAWERrtrc0NU+2ZjBC8u3cEv7RvxneAQB/hJSZvKqoDJiSSJHjutsDUYo3kn0103a/nT3jS75PkJUh6s2N3SlTzdl8vzSzdzUtiFxwyMlpCzAqy79TejX0e6bbctbkujq5/pfp664R1XR1xtRg7OKijR//WQrH6Wk88yd7WWTNmE5zqyoYoYvthzhuY820Tu0ATNHRhJY04NDyohVZBwZMHLhMJJRO/z2B/4J+AEztdZTjTiuoxxpCnu7gz7ep2W5Aw/lTRMOjV93zXJJVV0WyRnF211vY0lyOn+8oz3P3XOjsQtieuJSSc42ioX+zlbpKSM4u8yRu/x321GeXbKJiNY3MHtUb2rV9PDX8UZMqjoyYOTCYSSn/08opfyA6cA9QAaQrJT6TGu93dljV0dVm6Ki3UFLNzUs+1x704TTf9hzzSaJpaF19TGMVFSkmfzpNhavP8xTd7Tj+d8aHFLgmePYzjaKRf7OVuspX/BN2jH+uGgjPUKCmDO6D7Wv9/CQ8jJG3KPqA+zVWu/XWucDS4ABBhzXpYzYtPDqkCpV9gzLaFprXvpsG4uSDvPk7e1ku2vv5JE95am+23GcpxZtoFtwEPPG9KGOhJTlGBFUwUDZ7WszSh6zNFdNCLqS1pqXPk1jYeJhfn9bO17oJyHlparUU0qp8UqpFKVUysmTJ91WnDf5cdcJnly4gc7N6zFvTB/qBvibXZKww21Tf1ZrKldNCLqK1pqXP0tjQeIhnvhNW/7SX0LK12mt47TWkVrryMaNG5tdjsf5356TjF+QSoemdVgwJoqgQAkpqzIiqDKBsrsChpQ8dgWrNZURmxZ2aFLb7rH7tmtgXKEUh9Srn29n/rpDjP9NWybe20lCyrtVqadE9f2y9xTjZv1MW9sBFmYNJuhvjbxus0FDlDdIZO9xR57rICMuxiYDHZRSbShupsHAEAOOW6HypvCqutafoxOCy1IOX3HvqVerIBJib+Kef/x4xb2qDk1q82hkK/pO/d6Qqb/SkJr7y0HG3dKGF6saUo4sn+LsqKm718gr7/t5D1N6ylck7s9izLxkQtUxEmq+yQ3qqp8l7/7Zcowj/evCYSSng0prXaiU+iOwiuJR2tla6zSnK6tAeVN4KYdOX7Hwa2Vr9VV1QnDyyq3XDEis3XeaofHryDiTd8XjB7MuMmHZ5svLKDmzCKfWmte+KA6psbe0YdL9nV1zJuXsqKm718jz8l8kZvSUr1h/4DRj5ibT8oZaJOS8SUN1zuySRBUYco9Ka/2V1vpGrXU7rfUbRhyzIu7enHBxUrrdx9fuO33teoE2fc1af9WpQWvNlC93MGftQUb3DWWyq0JKWJK7e8oXpB46zeg562kWFEBCbBSN1FmzSxJV5JFLKLl7c8LyjusIR2rQWvPGlzuY9fMBRt0cyku/6yIhJYQTNh4+w8jZyTSpF8Di2Gia1A0wuyThAI8MqvKm8PzK+WXu7NReecd1RFVr0Frz1tc7mVkSUi8/ICElhDM2p2czYtZ6GtapyeLYaJrWk5DyNB75zrby1tN7OCLYJZsTxkS1ZGHi4Wse79uuARsO51y5XqCfAn3lliFVrUFrzdSvdxK3Zj8jbmotISWEk7Zl5jB8VhL1a/uzODaaZkFlQs
oqS2ZZeXDJIjwyqCqb2Cu7ft/DEc6vMzZlYNg1x42JasmUgWGG7RKstebt/+7iP2v2Mzy6Na8+2NXckHJ3E7uyAa28PqFwmbQjOQydmUTdgOKQuuaqhlV+sVt5cMkiPDKowP7EXkXr9xkRVqWBVVkdpY9Xldaav63axYyf9jEsuhWvDTA5pMD9TWxUA76S43wtwuPtPHaWYTOTqF3Tj8Wx0YTcUMvskoQTPPIeVXkcWb/PKrTWvPPNLv794z6GRLXitQe7mR9SQniwPcfPMTQ+ietr+LEoNppWDSWkPJ1XBZWnbXuttebv3+xm+g/7iOnTiikDunHddRJSQlTX3hPniYlPwu86xaLYKEIb2V89RngWrwoqT9v2+t3Ve/jgh70M7t2SNwZKSAnhjP0nzzMkPhGARbHRtG0s9ya9hcfeo7LHnbvrOuvdb3fz/nd7eDyyJW8+FCYhVRFHloO6mo9OSfmag6cuEBOfiK1Is2R8NO2beFBIlTe4BFX/2ffygSGvCipP2fb6vdW7+ed3e3g0IoS3BrkopCr64bciI+q116w+OiXlSw5nXSQmPpH8wiKWjL+JDk3rml2SY+y9YKoooHxwYMirggqsv+31+9/t4b3Ve3gkIoS3H+7uujMpR3/4zVbe2Y00rKhA+unikMotsLFoXDQdm3lYSIkq8ap7VFb3r+/28I9vd/NwLxeHlBA+IDM7lyEzEzmXV8DCsVF0aVHP7JKEi3jdGZVVTf9hL3//djeDwoP52yPd8ZOQEqLajubkMiQ+keyLBSSMi6JbsIWvFginyRmVG0z/YS/TVu3iofBgpj3aQ0JKCCccP5vHkPgkss7nM39MH7qH1De7JOFickblYv/+cR/TVu1iYM8WvHN1SLlyIs3RzQXt3QtyZK0xe4yYRHJ2KSerrOcmDHHiXB4x8YmcOJvH/LF9CG91g9kl2edsb8vP7RUkqFxoxk/7ePu/OxnQswV/f6zntWdSrpxIq+gYVw8hlDew4MhaY/aOawRnA1tG0L3GqfOXGBKfxLGcPOaN6UNE6wZml1Q+Z3tbfm6vIJf+XCRuzT6mfr2TB3q04O9yuU8Ip2Sdv8TQ+CQyzlxk9qje9A61cEgJwzkVVEqpR5VSaUqpIqVUpFFFebr4NTT1FREAAA60SURBVPt586ud/K57c959rAc1/OT1gKga6alrnbmQz9CZSRzMusDskb2JbtvQ7JKEmzn7G3QbMAhYY0AtXmHm//bzxlc7uL97c957vKeElHCU9FQZ2ReLQ2r/qQvMHBnJze0bmV2SMIFT96i01jsAWe27xKyfDzDlyx3cH9acf0pIiWqQnvpVTm4Bw2etZ++J88SNiODWDo3NLkmYRIYpDDL75wO8/sV27u3WjPcGVzGkjJjscXS6D5xbP0ymkSxDKTUeGA/QqlUrk6sx1tm8AkbMXs/OY2f5z/AIbu/YxOyShIkqDSql1GqgmZ1PTdJaf1rVb+TNTTV37QFe+2I7/bs24/2YcPyreiZlxGSPEdN99p5bHplGcppRPaW1jgPiACIjI7VB5Znu/KVCRs1eT1pmDv8eFsGdnZqaXZIwWaVBpbW+24hv5K1NNe+Xg7zy+Xb6dW3Kv4Y4EFLCZxnVU97owqVCRs9Zz+aMHKYP6cU9XSSkhIynO2X+uoO8/Fkav+3SlH/F9JKQEsIJF/MLGT03mQ2Hs3l/cDj9u9k76RS+yNnx9IeUUhnATcCXSqlVxpRlfQsSD/HSp2nc06UpHwzpRc0aElLCeb7aU7n5NsbOTSHl4Gn+8VgP7u/e3OyShIU4O/X3CfCJQbV4jIWJh/h/K7dxd+cmTJeQEgbyxZ7KK7AROz+FxANZ/OOxHgzoadFtehxZFkmGjgwlU38OWpR0mMkrt3FXpyZMH+pASFlhp1krNI8V/h2EZeQV2HhiQSpr951i2iM9eCg8xOySyufIskjys2woCSoHLF5/mL9+spU7OzXhw2G9uL6GX9W/2Ao7zVqheazw7y
As4VKhjT8kbOCn3Sd5++EwHomwcEgJU8k1qypasv4wL67Yyh0dG/NvR0NKCHGF/MIinkrYyPc7T/DmQ2E83tu73rIijCVBVQVLk9OZuGIrt3dszL+HRUhICeGEAlsRTy/ewOodx3l9QFeGRElIiYpJUFViaUo6f1mxhd/c2JgZwyII8JeQEqK6Cm1FPLtkI6vSjvPyA10YflOo2SUJDyD3qCqwLCWdvyzfwi3tGxE33E5IyWCAEFVWaCviuaWb+WrrMSbf35nRfduYXZJ9VV1irJRM8rmcBFU5Pk7N4IWSkIofEWn/TMqRwQBXTdxZYZLPEZ5WrzCErUjzf8s28/nmI7x4byfG3drW7JKqxxWbg4pKSVDZsWJDBhM+3kzfdhWElKNcdYblaWdunlavcFpRkeaFj7ewctMRJvTryBO3tTO7JOFh5B7VVT7ZmMHzyzZzc7uGxoWUED6qqEjz4oqtLN+QwZ/vuZGn7mhvdknCA0lQlbFyYybPL93MTW0bMnNEbwJrSkgJUV1FRZpJK7fxUUo6z9zVgWfu6mB2ScJDSVCV+HRTJn9euomoNg2ZNVJCSghnaK15+bM0Fq8/zFN3tOO5uyWkRPX5xD2qlRszmbZqF0eyc2lRP5AJ/ToyMPzX9cQ+23yE5z7aRJ82DZg1KrLqIWWFwQBPmzz0tHqFw7TWvPr5dhYkHuKJ37Tl/37b0bkdi+Vnxud5fVCt3JjJiyu2kltgAyAzO5cXV2wFYGB4MJ9vPsKflmwkMrQBs0f1plZNB/5JrNAknrYkkafVKxyitWbKlzuY+8tBxt7Shon3dnIupMD9PzMy2Wc5Xh9U01btuhxSpXILbExbtYsafoo/fbSJyNYNmONoSAkhrqC1ZurXO5n18wFG3RzK5Ps7Ox9SQuAD96iOZOfafTwzO5dnl2yiV6v6zBndm9rXS0gJUV1aa6at2sV/1uxnWHQrXn6gi4SUMIzXB1WL+oHlfi68ZX3mjO4jISWEk95dvYcPf9xHTJ9WvPZgNwkpYSivD6oJ/ToSaOe9UG0a1mbumD7UkZASwin/XL2H97/bw2ORIbwxsBvXXSchJYzl1G9ppdQ04AEgH9gHjNZaZxtRmFFKp/umrdpFZsllwNCGtfjs6b7eEVJWmDwsT3nTWvZYoV4L8ISeKmv6D3t5d/VuHu4VwtRB3V0TUlb+GTeCTDVWytnf1N8CL2qtC5VSbwMvAn9xvixjDQwPplZNP/6QsIGwkCDmj+lD3QB/s8syhpV/kCsKKZmsKo9H9BTAjJ/2MW3VLgb2bMHfHnFRSIG1f8aNIJOwlXLq0p/W+hutdWHJh4mAJbfo/Hb7cZ5atIFuwUHM86aQEl7HU3pq5v/2M/XrnTzQowXvPNoDP7ncJ1zIyHtUY4Cvy/ukUmq8UipFKZVy8uRJA79txVZvP84fElLp0iKI+WP7UE9CSngOS/bUnLUHmPLlDu4Pa867j/Wghp/X3+oWJqv00p9SajXQzM6nJmmtPy15ziSgEEgo7zha6zggDiAyMlJXq1oHfbfjOE8mpNKleT3mj5GQEtbgyT21YN1BXv18O/26NuW9wT0lpIRbVBpUWuu7K/q8UmoU8DvgLq21W5qlKr7feZwnF26gc/N6zB8bRVCghJSwBk/tqUVJh/l/n6Zxd+em/CumF/4SUsJNnPpJU0r1B14AHtRaXzSmJOf9sPMEv1+wgY7N6rJgjISUacqbyvKWaS0XsGpPLU1O56+fbOXOTk2YPjScmjUkpAwjfVIpZ6f+PgCuB74teYNfotb6905X5YQfd53giQWp3NisDgvHRhFUS0LKNN4+reUaluupj1Mz+MuKLdx2Y2M+HNqL62vIzgKGkj6plFNBpbW21C5oP+0+yfgFqXRoKiElPJPVeuqTjb/udv2f4RGykagwhdecv6/ZfZLY+Sm0b1yHhHFR1K9V0+yShPBon20+wvNLNxPdRna7FubyiqD6357ikGonISWEIb7ccpTnPtpEZK
iDe7QJ4QIev4bQz3tOMW5eCm0a1SZhXBQ31DY5pGQ5FOHh/rvtKM8s2Vi8aLNsfyMswKPPqNbuPcXYecm0aVSbRbHRNDA7pECWQxEe7Zu0Y/xx0UZ6hAQxd4zsLCCswWOD6pcyIZUwLsoaISWEB/tuR/FSY12Dg2RnAWEpHhlUv+w7xZh5ybRuUBxSDetcb3ZJQni0H3ed4MmFG+jUTFZxEdbjcUG1bl8WY+Ym06pBLRJiJaSEcNb/9vz6to4FY/vIG+SF5XhUUCXuLw6pljfUYlFsNI0kpIRwyi97i4eR2jaqzcKxMjErrMljgippfxaj5yQTfEOgtUNKlkMRHiJxfxZj5iUT2tAiE7NClMMj7pauP3Ca0XOTaVE/gEWxUTSua9GQAhlBFx4h+eDpy1cn5BK6sDrLn1ElHzzNqDnraRYUwOLYaJrUDTC7JCE8WuqhM4yaXdxTCbFR1r06IUQJSwdVysHTxQ1VL4AlsdE0qSchJYQzNh4+w8jZ62lST174Cc9h2aBKPXSakbPX07ReAIvHS0gJ4awtGdmMmL2ehnVqsjg2mqbSU8JDWDKoUg+dYeTs5OJXfeOloYRw1rbMHIbNTCIo0J9FsdE0C5KeEp7DckG1oeTSRCN51SeEIbYfOcuwWUnUDfBncWw0wfUDzS5JCIdYKqi01rz+xfbiSxPj5VWfEEZ46+sdBPr7sTg2mpYNapldjhAOc2o8XSn1OjAAKAJOAKO01kecOB7/GR5BoU3TPEhe9QlhhH/FhHM2t5BWDSWkhGdy9oxqmta6u9a6J/AF8JKzBTWpG0ALuTQhfJRS6nWl1Bal1Cal1DdKqRbOHrN+rZoSUsKjORVUWuuzZT6sDWjnyhHC5xn+4k8IT+f0yhRKqTeAEUAOcIfTFQnhw+TFnxDXqvSMSim1Wim1zc6fAQBa60la65ZAAvDHCo4zXimVopRKOXnypHF/AyG8jFLqDaVUOjCUCs6opKeEr1BaG/OCTSnVCvhKa92tsudGRkbqlJQUQ76vEGZQSqVqrSOr+bWrgWZ2PjVJa/1pmee9CARorV+u7JjSU8IblNdXzk79ddBa7yn5cACw05njCeELtNZ3V/GpCcBXQKVBJYQ3c+qMSim1HOhI8Xj6IeD3WutKlw9XSp0seX55GgGnql2Y8axWD1ivJqvVA66tqbXWurHRBy374k8p9TRwm9b6kSp8naf1FFivJqmncq6uyW5fGXbpz0hKqZTqXlZxBavVA9aryWr1gDVrqkx1X/xV4biW+7ewWk1ST+XMqskj9qMSwldorR82uwYhrMZSSygJIYQQV7NqUMWZXcBVrFYPWK8mq9UD1qzJLFb8t7BaTVJP5UypyZL3qIQQQohSVj2jEkIIIQALB5VSappSamfJAp2fKKXqm1zPo0qpNKVUkVLKtEkcpVR/pdQupdRepdREs+ooU89spdQJpdQ2s2sBUEq1VEr9oJTaXvL/61mza7IK6akKa7FMX1mtp8D8vrJsUAHfAt201t2B3cCLJtezDRgErDGrAKWUHzAduBfoAsQopbqYVU+JuUB/k2soqxB4XmvdBYgGnrLAv5FVSE/ZYcG+mou1egpM7ivLBpXW+hutdWHJh4lAiMn17NBa7zKzBqAPsFdrvV9rnQ8soXhFENNordcAp82soSyt9VGt9YaS/z4H7ACCza3KGqSnymWpvrJaT4H5fWXZoLrKGOBrs4uwgGAgvczHGcgv4XIppUKBcCDJ3EosSXrqV9JXDjCjr0x9w29VFudUSk2i+LQzwQr1CM+glKoDLAf+dNXWGV5Nekq4kll9ZWpQVbY4p1JqFPA74C7thjl6BxYLNUsm0LLMxyElj4kylFL+FDdTgtZ6hdn1uJP0VLVIX1WBmX1l2Ut/Sqn+wAvAg1rri2bXYxHJQAelVBulVE1gMPCZyTVZilJKAbOAHVrrf5hdj5VIT5VL+qoSZveVZYMK+ACoC3yrlNqklJphZjFKqYeUUhnATcCXSq
lV7q6h5Eb4H4FVFN/MXKq1TnN3HWUppRYD64COSqkMpdRYM+sB+gLDgTtLfm42KaXuM7kmq5CessNqfWXBngKT+0pWphBCCGFpVj6jEkIIISSohBBCWJsElRBCCEuToBJCCGFpElRCCCEsTYJKCCGEpUlQCSGEsDQJKiGEEJb2/wF5x+N/Q0V3uwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 504x216 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pcdwOVLQAsMP",
        "colab_type": "text"
      },
      "source": [
        "## Low-level implementation using autograd"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x6k7am1rAsMQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def custom_where(cond, x_1, x_2):\n",
        "    \"\"\"Arithmetic element-wise select: x_1 where cond == 1, x_2 where cond == 0.\n",
        "\n",
        "    `cond` is expected to be a 0./1. float tensor (see predict_labels).\n",
        "    \"\"\"\n",
        "    return (cond * x_1) + ((1-cond) * x_2)\n",
        "\n",
        "\n",
        "class LogisticRegression2():\n",
        "    \"\"\"Logistic regression whose gradients are computed by autograd.\n",
        "\n",
        "    The parameters are plain leaf tensors and the gradient-descent\n",
        "    update is written out by hand (no torch.optim).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_features):\n",
        "        self.num_features = num_features\n",
        "        \n",
        "        self.weights = torch.zeros(num_features, 1, \n",
        "                                   dtype=torch.float32,\n",
        "                                   device=device,\n",
        "                                   requires_grad=True) # req. for autograd!\n",
        "        self.bias = torch.zeros(1, \n",
        "                                dtype=torch.float32,\n",
        "                                device=device,\n",
        "                                requires_grad=True) # req. for autograd!\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return sigmoid(x @ weights + bias), shape (n_samples, 1).\"\"\"\n",
        "        linear = torch.add(torch.mm(x, self.weights), self.bias)\n",
        "        probas = self._sigmoid(linear)\n",
        "        return probas\n",
        "\n",
        "    def predict_labels(self, x):\n",
        "        \"\"\"Threshold the predicted probabilities at 0.5; returns 0./1. floats.\"\"\"\n",
        "        # Inference only: don't build an autograd graph for this pass.\n",
        "        with torch.no_grad():\n",
        "            probas = self.forward(x)\n",
        "        labels = custom_where((probas >= .5).float(), 1, 0)\n",
        "        return labels\n",
        "\n",
        "    def evaluate(self, x, y):\n",
        "        \"\"\"Return the fraction of samples whose predicted label equals y.\"\"\"\n",
        "        labels = self.predict_labels(x)\n",
        "        accuracy = (torch.sum(labels.view(-1) == y.view(-1))).float() / y.size()[0]\n",
        "        return accuracy\n",
        "\n",
        "    def _sigmoid(self, z):\n",
        "        return 1. / (1. + torch.exp(-z))\n",
        "\n",
        "    def _logit_cost(self, y, proba):\n",
        "        \"\"\"Summed binary cross-entropy: -sum[y*log(p) + (1-y)*log(1-p)].\n",
        "\n",
        "        For proba of shape (n, 1), as returned by forward, the result is\n",
        "        a (1, 1) tensor.\n",
        "        \"\"\"\n",
        "        tmp1 = torch.mm(-y.view(1, -1), torch.log(proba))\n",
        "        tmp2 = torch.mm((1 - y).view(1, -1), torch.log(1 - proba))\n",
        "        return tmp1 - tmp2\n",
        "\n",
        "    def train(self, x, y, num_epochs, learning_rate=0.01):\n",
        "        \"\"\"Full-batch gradient descent for num_epochs, logging every epoch.\"\"\"\n",
        "        for e in range(num_epochs):\n",
        "\n",
        "            #### Compute outputs ####\n",
        "            proba = self.forward(x)\n",
        "            cost = self._logit_cost(y, proba)\n",
        "\n",
        "            #### Compute gradients ####\n",
        "            cost.backward()\n",
        "\n",
        "            #### Update weights ####\n",
        "            # The update itself must not be tracked by autograd;\n",
        "            # torch.no_grad() is the documented way to modify leaf\n",
        "            # tensors that require grad in place.\n",
        "            with torch.no_grad():\n",
        "                self.weights -= learning_rate * self.weights.grad\n",
        "                self.bias -= learning_rate * self.bias.grad\n",
        "\n",
        "            #### Reset gradients to zero for next iteration ####\n",
        "            self.weights.grad.zero_()\n",
        "            self.bias.grad.zero_()\n",
        "\n",
        "            #### Logging ####\n",
        "            # Evaluation-only forward passes: skip graph construction.\n",
        "            with torch.no_grad():\n",
        "                train_acc = self.evaluate(x, y)\n",
        "                train_cost = self._logit_cost(y, self.forward(x))\n",
        "            print('Epoch: %03d' % (e+1), end=\"\")\n",
        "            print(' | Train ACC: %.3f' % train_acc, end=\"\")\n",
        "            print(' | Cost: %.3f' % train_cost)"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "TASOgZKVAsMV",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 276
        },
        "outputId": "cdbeeb00-c821-45d1-f8b8-2053f2d09cbd"
      },
      "source": [
        "# Convert the training arrays to float32 tensors on `device`.\n",
        "X_train_tensor, y_train_tensor = (\n",
        "    torch.tensor(arr, dtype=torch.float32, device=device)\n",
        "    for arr in (X_train, y_train))\n",
        "\n",
        "logr = LogisticRegression2(num_features=2)\n",
        "logr.train(X_train_tensor, y_train_tensor,\n",
        "           num_epochs=10, learning_rate=0.1)\n",
        "\n",
        "print('\\nModel parameters:')\n",
        "print('  Weights: %s' % logr.weights)\n",
        "print('  Bias: %s' % logr.bias)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001 | Train ACC: 0.987 | Cost: 5.581\n",
            "Epoch: 002 | Train ACC: 0.987 | Cost: 4.882\n",
            "Epoch: 003 | Train ACC: 1.000 | Cost: 4.381\n",
            "Epoch: 004 | Train ACC: 1.000 | Cost: 3.998\n",
            "Epoch: 005 | Train ACC: 1.000 | Cost: 3.693\n",
            "Epoch: 006 | Train ACC: 1.000 | Cost: 3.443\n",
            "Epoch: 007 | Train ACC: 1.000 | Cost: 3.232\n",
            "Epoch: 008 | Train ACC: 1.000 | Cost: 3.052\n",
            "Epoch: 009 | Train ACC: 1.000 | Cost: 2.896\n",
            "Epoch: 010 | Train ACC: 1.000 | Cost: 2.758\n",
            "\n",
            "Model parameters:\n",
            "  Weights: tensor([[ 4.2267],\n",
            "        [-2.9613]], requires_grad=True)\n",
            "  Bias: tensor([0.0994], requires_grad=True)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ny4rD37TAsMa",
        "colab_type": "text"
      },
      "source": [
        "#### Evaluating the Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "YWRhbDICAsMa",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "ddb020ca-45c9-4cc8-e8a9-6c2265ce592d"
      },
      "source": [
        "# Convert the test arrays to float32 tensors on `device`.\n",
        "X_test_tensor, y_test_tensor = (\n",
        "    torch.tensor(arr, dtype=torch.float32, device=device)\n",
        "    for arr in (X_test, y_test))\n",
        "\n",
        "test_acc = logr.evaluate(X_test_tensor, y_test_tensor)\n",
        "print('Test set accuracy: %.2f%%' % (test_acc*100))"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test set accuracy: 100.00%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ob8uRBqPAsMf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 211
        },
        "outputId": "128ee7de-29ce-4861-f217-020b5d70073b"
      },
      "source": [
        "##########################\n",
        "### 2D Decision Boundary\n",
        "##########################\n",
        "\n",
        "# Detach the learned parameters from the autograd graph and copy them\n",
        "# to the CPU so the values can be handed to matplotlib even when\n",
        "# `device` is a GPU.\n",
        "w = logr.weights.detach().cpu()\n",
        "b = logr.bias.detach().cpu()\n",
        "\n",
        "# The boundary is where w[0]*x + w[1]*y + b = 0 (sigmoid output 0.5),\n",
        "# so y = (-w[0]*x - b) / w[1]; .item() turns each endpoint into a float.\n",
        "x_min = -2\n",
        "y_min = ( (-(w[0] * x_min) - b[0]) \n",
        "          / w[1] ).item()\n",
        "\n",
        "x_max = 2\n",
        "y_max = ( (-(w[0] * x_max) - b[0]) \n",
        "          / w[1] ).item()\n",
        "\n",
        "\n",
        "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n",
        "\n",
        "ax[0].plot([x_min, x_max], [y_min, y_max])\n",
        "ax[1].plot([x_min, x_max], [y_min, y_max])\n",
        "\n",
        "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n",
        "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n",
        "\n",
        "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n",
        "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n",
        "\n",
        "ax[1].legend(loc='upper left')\n",
        "plt.show()"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaoAAADCCAYAAAAYX4Z1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deVyVdfr/8dcnxMANc1dQccsVFUGhrGkfraY028R9w6Zpqmn62tjor9XKxmZqmmwccFfUNM32sWxzMkHAHfcdcEfBDQQOn98fgKEelsO5z7nvc871fDx8POJwuLk0Lt7nvu/rfD5Ka40QQghhVdeZXYAQQghREQkqIYQQliZBJYQQwtIkqIQQQliaBJUQQghLk6ASQghhaTXM+KaNGjXSoaGhZnxrIQyRmpp6Smvd2Ow6SklPCW9QXl+ZElShoaGkpKSY8a2FMIRS6pALjx0ArAGup7hHP9Zav1zR10hPCW9QXl85HVTVaSohRIUuAXdqrc8rpfyBn5VSX2utE80uTAgzGHFGJU0lhIF08XIx50s+9C/5I0vICJ/ldFBJU3mWlRszmbZqF0eyc2lRP5AJ/ToyMDzY7LLEVZRSfkAq0B6YrrVOsvOc8cB4gFatWrm3QCHcyJB7VFVpKmG+lRszeXHFVnILbABkZufy4oqtABJWFqO1tgE9lVL1gU+UUt201tuuek4cEAcQGRl5zYvDgoICMjIyyMvLc0vNniIgIICQkBD8/f3NLkVUkSFBVZWmkld/5pu2atflkCqVW2Bj2qpdElQWpbXOVkr9APQHtlX2/LIyMjKoW7cuoaGhKKVcU6CH0VqTlZVFRkYGbdq0MbscjzLjp33Uub4Gw6Jbu/17G/o+Kq11NlDaVFd/Lk5rHam1jmzc2DJTvT7lSHauQ497ssNZF8nJLTC7jGpRSjUuedGHUioQuAfY6ehx8vLyaNiwoYRUGUopGjZsKGeZDopfs5+pX+8k+eBpzNhxw+mgMqqphOu1qB/o0OOe6lDWBR6PW8ezSzaaXUp1NQd+UEptAZKBb7XWX1TnQBJS15J/E8fM/vkAb3y1g/vDmvP3R3uY8u9nxBmVYU0lXGtCv44E+vtd8Vigvx8T+nU0qSLjHc66SExcInkFNv7Sv5PZ5VSL1nqL1jpca91da91Na/2a2TUZ6ZVXXuGdd95xybFTU1MJCwujffv2PPPMM6a8+vcm89cd5LUvttO/azPeG9yTGn7mLGbk9Hf19qbyJgPDg3lrUBjB9QNRQHD9QN4aFOY196cOZ11kcNw6cgtsJIyLpnPzemaXJNzsySefJD4+nj179rBnzx7++9//ml2Sx0pIOsRLn6ZxT5emvB8Tjr9JIQUmrUwhzDMwPNhrgqms9NMXiYlP5GKBjYRxUXRpISHlKFe8dWH+/Pm88847KKXo3r07CxYsuOLz8fHxxMXFkZ+fT/v27VmwYAG1atVi2bJlvPrqq/j5+REUFMSaNWtIS0tj9OjR5OfnU1RUxPLly+nQocPlYx09epSzZ88SHR0NwIgRI1i5ciX33nuvU38HX/RR8mEmfbKNuzo1YfqQXtSsYe6ysBJUwuOln77I4LhEzl8qJGFcFF1bBJldksdxxVsX0tLSmDJlCr/88guNGjXi9OnT1zxn0KBBxMbGAjB58mRmzZrF008/zWuvvcaqVasIDg4mOzsbgBkzZvDss88ydOhQ8vPzsdmunGDNzMwkJCTk8schISFkZmZWq3ZftiwlnYkrtnJ7x8Z8OMz8kAJZPV14uIwzxWdSpSHVLVhCqjoqeutCdX3//fc8+uijNGrUCIAGDRpc85xt27Zx6623EhYWRkJCAmlpaQD07duXUaNGER8ffzmQbrrpJt58803efvttDh06RGCgdw0BWcEnGzN4YfkWbmnfiBnDIri+hl/lX+QGElTCY2Vm5zI4LpGzuQUSUk4y660Lo0aN4o
MPPmDr1q28/PLLl8fGZ8yYwZQpU0hPTyciIoKsrCyGDBnCZ599RmBgIPfddx/ff//9FccKDg4mIyPj8scZGRkEB1fvbHDlxkz6Tv2eNhO/pO/U71m50fvPzD7dlMnzSzdzU9uGxI+IJMDfGiEFElTCQxWH1LqSkIqWkHKSK966cOedd7Js2TKysrIA7F76O3fuHM2bN6egoICEhITLj+/bt4+oqChee+01GjduTHp6Ovv376dt27Y888wzDBgwgC1btlxxrObNm1OvXj0SExPRWjN//nwGDBjgcN2ll0Ezs3PR/HoZ1JvD6ostR3juo030Dm3AzJHWCimQoBIe6Eh2LjFxiWRfLGDhuCjCQiSknOWKty507dqVSZMmcdttt9GjRw/+/Oc/X/Oc119/naioKPr27UunTr++nWDChAmEhYXRrVs3br75Znr06MHSpUvp1q0bPXv2ZNu2bYwYMeKa43344YeMGzeO9u3b065du2oNUrjiMqiVfb31KM8u2URE6xuYPao3tWpab3RBmfE+g8jISC1754jqOJpTfLnv9Pl8Fo6LokfL+qbUoZRK1VpHmvLN7bDXUzt27KBz585VPoYvLVhc0b9Nm4lf2l1VWwEHpt7v0rrc7Zu0Y/whYQPdQ4KYPzaKOtebG1Ll9ZX1olOIchzLybscUgtMDClv5a1vXXBUi/qBZNq5N+dtK7h8t+M4Ty3aQLfgIOaN6WN6SFVELv0Jj1AcUuvIOp/P/LF96CkhJVzEF1Zw+WHXCZ5cuIHOzesxb0wf6gZYeyV560aoECWOn80jJj6RUyUhFd7qBrNLEhZ25mI+x3PyyLcVUdPvOpoGBXBDrZpV/vrSs0pvvQy6ZvdJnliQSoemdVgwJoqgQGuHFEhQCYs7fjaPmLhETpzNY/7YKHpJSIkKnLmYT+aZXIpK7r3n24rIPFN8Gc/RsPKWYCpr7d5TxM5PoV3jOiwcG0VQLeuHFMilP2FhJ0rOpI6fzWP+2D5EtJaQEhU7npN3OaRKFWnN8RzZ1mPdvizGzkumTaPaJIyL4obaVQ9us8kZlbCkE+eKQ+pYTh7zx/QhovW1qxoIcbV8W5FDj/uK9QdOM2ZuMi1vqMXCcVE08KCQAjmjEhZ08twlYuISOZqTx7wxfYgMlZDyJq7c5uPDaVP4bZ+uRHcMueLxmiau/G221EOnGT1nPS3qB7AoNppGda43uySH+e7/PWFJJ89dIia+OKTmju5Dbwkp4YBBDw1g8RdXLq10nVI0DQowqSJzbTx8hpGzk2laL4DFsdE0rut5IQVy6U9YyKnzlxgSn0jmmVzmju5NnzYSUm7zZjDkn7/28Zp14K/VXzrIndt8APz29ls5czGf0j1oqzP15y02p2czYtZ6GtapyaLYaJrU89ywlqASgPmrEpSGVMaZXOaM7k1U24Zu+95Wo5RqCcwHmgIaiNNa/9Ol39ReSFX0eBW4e5uPUjfUqolS0D3Ed99rty0zh+Gzkqhf25/FsdE08/AzSqcv/SmlWiqlflBKbVdKpSmlnjWiMOE+Zi/CmXX+EkPjkzh8+iKzRkUS7cMhVaIQeF5r3QWIBp5SSnUxuSaHyTYf5kg7ksPQmUnUDSgOKW9YUcOIe1Re0VS+zMxFOLPOX2LozCQOnb7A7JG9ubldI5d/T6vTWh/VWm8o+e9zwA7A+97Ug7HbfAjYeewsw2YmUbumH4tjowm5oZbZJRnC6aDypabyVmbtRXT6Qj5DZyZx4NQFZo3szc3tJaSuppQKBcKBJDufG6+USlFKpZw8edLdpVXK3dt8+Lrdx88xND6J62v4sSg2mlYNvSOkwOCpP09uKl/mir2IKnPmqpDqKyF1DaVUHWA58Cet9dmrP6+1jtNaR2qtIxs3buz+AithxjYfL7zwAiEhIVy8eJGQkBBeeeUVV/4VLWPviXMMiU/E7zrFotgoQhvVNrskQxm2zUdJU/0EvKG1XlHRc2WbD2spvUdV9vJfoL8fbw0Kc8
lARWlI7T15nlkjI7m1g/V+yVbG1dt8KKX8gS+AVVrrf1T2fKe3+XDR1J9VOboFipXtO3mewXGJaA1LxkfTvkkds0uqNpdu81HSVMuBhMpCSlTf5JVbWZyUjk1r/JQiJqolUwaGOX1cdy7CmX0xn2GzikNq5gjPDClXU0opYBawoyohZQgvDCNfcPDUBYbEJ1JUpD0+pCridFCZ0lQ+aPLKrSxMPHz5Y5vWlz82KqxcPY5eGlJ7TpwnfkQkv7lRQqocfYHhwFal1KaSx/6qtf7KxJqExRzOukhMfCIFNs3i2Gg6NK1rdkkuY8Q9qtKmulMptankz30GHFeUsTgp3aHHrSbnYgHDZ61n97HzxA2P4DYJqXJprX/WWiutdXetdc+SPxJS4rL008UhlVtgY+HYKDo2896QAgPOqLTWP8PlN4ILF7GVcy+xvMetJCe3gOGzk9h17Bz/GR7B7R2bmF2Sz9BaU3zRQ5Qy6r68WTKzc4mJT+RcXgGLYqPp0qKe2SW5nKz15yH8yvllU97jVpGTW8CIWUnsPHqOGcN7cUcnCSl3CQgIICsry+N/MRtJa01WVhYBAZ65UsPRnFxi4hLJyS1g4bgougUHmV2SW8gSSh4iJqrlFfeoyj5uVWfzChgxez3bj55lxrAI7uzU1OySfEpISAgZGRl489tBsi/mc+GSDU3xZZ3a1/tRv5J1/QICAggJCanwOVZ0/GweQ+KTOHMhnwXjonxqiSgJKjdyZj290oEJZ6f+3LWm39m8AkbMWs/2Izn8e2gEd3WWkHI3f39/2rRpY3YZLlM8YJRxzePDolsZMmBkJaX7s5XudN2zpe+EFEhQuc3V71UqXU8PcCisnGlAI2qoinN5BYycvZ60Izl8ODSCu7tISAnjVTRg5E1BVbxgcxLHSvZn88WdruUelZuYuZ6eO2soDamtGTlMH9KLeySkhIt48oBRVZUu2Jx5Jpc5o3r77P5sElRuYtZ6eu6s4fylQkbNSWZLRg4fDOnFb7s2M+S4QtjjqQNGVVW6gsvBrAvMGhnp01vfSFC5iRnr6bmzhvOXChk1ez2b07P5YEg4/btJSAnXKm+QyJEBo5UbM+k79XvaTPySvlO/d9vWNpXJvlgcUvtPXWDmyEifX7BZgspNJvTriP91V77S879OMaFfR7vPd0UDTejXkUB/vyseC/T3K7eGqjp/qZDRc9azMT2bf8WE079bc6eOJ0RVTBkYxrDoVpfPoPyUcmiQwux92MqTk1v85vi9J4rfHC/LjMkwhXtdfUWinCsUrhp6cMWafhcuFTJmTjIbDheH1L1hElLCfZwZMKronq07d7cuq/QtHTuPnZU3x5chQeUm01btosB25U3eApu22xSubCAj1/S7cKmQ0XOTST18hvcHh3OfhJTwIFa4b1zW5WnZzBz+Le87vIJc+nMTR5rCag1kz8X8QsbMTSb10Bnee7wn93eXkBKexQr3jUtduFTI6DKDSDIteyUJKjdxpCms1ED2lIZU8sHTvPt4Tx7o0cLskoRwmKvu2TrqYn7xlYmN6dm8P1gGkeyRoHITR5rCKg1kT26+jbFzU1h/oDikHpSQEh5qYHgwbw0KI7h+IAoIrh/oss1Cy1PaTyklL/rkyoR9Pn2PylXLCVV03Kp8P3duZOiI3HwbY+clk3Qgi3cf78mAnubWI4Sz3LEPW3nyCmzEzk8h8UAW7z4mL/oq4rNB5arJusqOW9Vjm9lA9uQV2Bg3P5l1+7P4x2M9JKSEcEJegY3xC1JZu+8U0x7pYaletyKfvfTnquWErLBUktHyCmyMm5fCL/uy+PujPXgo3PNWnhbCKi4V2vhDwgbW7D7J24O680iE9FNlfDaoXDVZ5wkTe44ovTxR+spvUC9pKndQSs1WSp1QSm0zuxZhnPzCIp5K2Mj3O0/w5kNhPNbbutv0WIkhQeWJTeWqyTqrT+w5ovTyxM97T/G3h+WVn5vNBfqbXYQwToGtiKcXb2D1juO8Pr
AbQ6JamV2SxzDqjGouHtZUrpqss/LEniPyCmw8sSCV/+05ydsPd+fRSHnl505a6zXAabPrsDpXrdVn9HELbUU8u2Qjq9KO88oDXRge3dqQOn2FIcMUWus1SqlQI47lLq6arBsYHsyylMOs3ffr75herYIYGB5sdxrQFTU4K6/Axu8XpvLT7pP87eHuPCYhZUlKqfHAeIBWrXzv1blZA1GOKrQV8dzSzXy19RiT7+/MqL7eu5mlqyht0N4tJUH1hda6WzmfL9tUEYcOHTLk+1pN8a6j124Z37ddAzYczrli0MLfT4GGgqJf/x8E+vu5/b0cZV0qtPH7Ban8sOskUweFMbiP7/0CrAqlVKrWOtLF3yOUCnqqrMjISJ2SkuLKciyn79TvybRz7ze4fiBrJ95piePaijTPL93Eyk1HePHeTjxxW7tq1+ULyusrtw1TaK3jtNaRWuvIxo29dzXg8nYdXbvv9DXTgAU2fUVIgbkTgpcKbTy5cAM/7DrJWxJSwuKsPhBlK9JM+HgzKzcdYUK/jhJSTvDZqT9XMWJ3UTMmBC8V2vjDwg2Xp5FiJKSExVl5IKqoSPPiii2s2JDJn++5kafuaO9UTb5OgspgRuwu6u4JweKR2Q18t/MEU2QayRKUUouBdUBHpVSGUmqs2TVZjVUHooqKNJNWbmVpSgbP3NWBZ+7q4FQ9wqBhipKmuh1opJTKAF7WWs8y4tieJiaqpdP3qNw5IZhfWMRTizawescJXh/YjWEyjWQJWusYs2uwuoHhwaQcOs3ipHRsWuOnFA9HOLaii7PLnV1Na81Ln21j8fp0nrqjHc/dLSFlBKOm/jyyqRxZ62/yyq1XNERMVEu7G7ZNGRhG0v4s9py4cPmxDk1qkxB7k93v50ijGb02YYGtiD8u2sC324/z2oCuVRuZfTMY8s9f+3jNOvBXa2zjLXzDyo2ZLE/NvHy53aY1y1MziWzdoEp9YdRyZ6W01rz6+XYWJh7midva8n+/7Ygy4AqL8OFLf45sQ106yVe2IRYmHmbyyq12n1s2pAD2nLjA5JVbGRgezNqJd3Jg6v2Xp4fsNZq9GozeNrs0pL7ZfpxXH+zKiJtCq/aF9kKqoseFcBFnlyszcrkzrTWvf7GDub8cZNwtbZjYv5OElIF8Nqgc+SEtb5LP3uOOPNeRGoxsqgJbEc8s/vXNhyNvDnX4GEKYzdnpPKOm+7TWvPX1TmavPcCom0OZdH9nCSmD+WxQOfJDWt4kn73HHXmuGbv+FpS8Q/7rbcd46Xdd5M2HwmM5O51nxHSf1pppq3YRt2Y/w6Nb8/IDXSSkXMBng8qRH9LyJvnsPe7Ic92962+hrYg/LdnEV1uP8f9+14Uxt0hICc/l7HSeEVOD767ew4c/7mNIVCtefbCrhJSL+GxQOfJDGhNlfwkhe4878lx37vpbaCvi2Y828eXWo0y+vzNjJaSEh3N2h15nv/6fq/fw/nd7eDyyJVMGdOO660wMqTeD4ZWga/+86R37XPnsxomOjKCWTvdVdervwMnzV6z117ddA6YMDLM7tffWoDCX7/pbutbYl1uOMum+zoy7tW3V/pHsqVmn/Kk/IdzM2Q1Gq/v103/Yy7urd/NIRAhvDQozN6TA64ecDFvrzxHevC7Z1SOvUHzm83BEMMtTM6953NXr+hXaivjz0s18tvkIf72vE+N/I8u4GMEda/05wpt7ympm/LSPqV/v5KHwYN55tAd+ZocUFJ89lfu5HPfV4STT1/rzFeVN5y1OSnf7zr+2Is3zy4pDauK9ElJCOGvm//Yz9eudPNCjBdMe6W6NkPIBElQGK28Kr7xpQFet61e6avOnm47wl/6d+L0siCmEU+asPcCUL3dwf1hz3n2sBzX85Nenu/jsPSpXaVE/0O4WAX5K2Q0rV6zrZyvSTFj266rNT94uISXMY/SKKmaYv+4gr36+nX5dm/Le4J6eFVJXXxb0wFVkvC6oHGkKV2xkOKFfRyZ8vJkC26
+h5O+neLx3S7v3qIxe1690a4EVGzOZ0K9j9Vdt9sWlknzx7+xirtrc0J0Skg7x0qdp3N25Kf+K6YW/FUOqvCEnezxwwMKrgsqRprD33AnLNoPicshUu6muPnHSENm6AZGtG7j0laWtSPPCx8VbCzzv7NYCXj5FZJcv/p1drKIVVTwhqD5KPsykT7ZxZ6cmTB8aTs0aFgwpsP9CqqIBCw/jVUHlSFPYe+7VmxhW9PUV1XD1cQqKit+9vnbinS5rzqIizcTlW1i+IYPn7r6Rp2VrAWEBrtrc0B2WpaQzccVWbruxMR8O7cX1Nfwq/yLhEhZ9eVA9RixJ5Mhxna3BKEVFmokrtrAsNYM/3d2BZ2VrAWERrtrc0NU+2ZjBC8u3cEv7RvxneAQB/hJSZvKqoDJiSSJHjutsDUYo3kn0103a/nT3jS75PkJUh6s2N3SlTzdl8vzSzdzUtiFxwyMlpCzAqy79TejX0e6bbctbkujq5/pfp664R1XR1xtRg7OKijR//WQrH6Wk88yd7WWTNmE5zqyoYoYvthzhuY820Tu0ATNHRhJY04NDyohVZBwZMHLhMJJRO/z2B/4J+AEztdZTjTiuoxxpCnu7gz7ep2W5Aw/lTRMOjV93zXJJVV0WyRnF211vY0lyOn+8oz3P3XOjsQtieuJSSc42ioX+zlbpKSM4u8yRu/x321GeXbKJiNY3MHtUb2rV9PDX8UZMqjoyYOTCYSSn/08opfyA6cA9QAaQrJT6TGu93dljV0dVm6Ki3UFLNzUs+1x704TTf9hzzSaJpaF19TGMVFSkmfzpNhavP8xTd7Tj+d8aHFLgmePYzjaKRf7OVuspX/BN2jH+uGgjPUKCmDO6D7Wv9/CQ8jJG3KPqA+zVWu/XWucDS4ABBhzXpYzYtPDqkCpV9gzLaFprXvpsG4uSDvPk7e1ku2vv5JE95am+23GcpxZtoFtwEPPG9KGOhJTlGBFUwUDZ7WszSh6zNFdNCLqS1pqXPk1jYeJhfn9bO17oJyHlparUU0qp8UqpFKVUysmTJ91WnDf5cdcJnly4gc7N6zFvTB/qBvibXZKww21Tf1ZrKldNCLqK1pqXP0tjQeIhnvhNW/7SX0LK12mt47TWkVrryMaNG5tdjsf5356TjF+QSoemdVgwJoqgQAkpqzIiqDKBsrsChpQ8dgWrNZURmxZ2aFLb7rH7tmtgXKEUh9Srn29n/rpDjP9NWybe20lCyrtVqadE9f2y9xTjZv1MW9sBFmYNJuhvjbxus0FDlDdIZO9xR57rICMuxiYDHZRSbShupsHAEAOOW6HypvCqutafoxOCy1IOX3HvqVerIBJib+Kef/x4xb2qDk1q82hkK/pO/d6Qqb/SkJr7y0HG3dKGF6saUo4sn+LsqKm718gr7/t5D1N6ylck7s9izLxkQtUxEmq+yQ3qqp8l7/7Zcowj/evCYSSng0prXaiU+iOwiuJR2tla6zSnK6tAeVN4KYdOX7Hwa2Vr9VV1QnDyyq3XDEis3XeaofHryDiTd8XjB7MuMmHZ5svLKDmzCKfWmte+KA6psbe0YdL9nV1zJuXsqKm718jz8l8kZvSUr1h/4DRj5ibT8oZaJOS8SUN1zuySRBUYco9Ka/2V1vpGrXU7rfUbRhyzIu7enHBxUrrdx9fuO33teoE2fc1af9WpQWvNlC93MGftQUb3DWWyq0JKWJK7e8oXpB46zeg562kWFEBCbBSN1FmzSxJV5JFLKLl7c8LyjusIR2rQWvPGlzuY9fMBRt0cyku/6yIhJYQTNh4+w8jZyTSpF8Di2Gia1A0wuyThAI8MqvKm8PzK+WXu7NReecd1RFVr0Frz1tc7mVkSUi8/ICElhDM2p2czYtZ6GtapyeLYaJrWk5DyNB75zrby1tN7OCLYJZsTxkS1ZGHi4Wse79uuARsO51y5XqCfAn3lliFVrUFrzdSvdxK3Zj8jbmotISWEk7Zl5jB8VhL1a/uzODaaZkFlQs
oqS2ZZeXDJIjwyqCqb2Cu7ft/DEc6vMzZlYNg1x42JasmUgWGG7RKstebt/+7iP2v2Mzy6Na8+2NXckHJ3E7uyAa28PqFwmbQjOQydmUTdgOKQuuaqhlV+sVt5cMkiPDKowP7EXkXr9xkRVqWBVVkdpY9Xldaav63axYyf9jEsuhWvDTA5pMD9TWxUA76S43wtwuPtPHaWYTOTqF3Tj8Wx0YTcUMvskoQTPPIeVXkcWb/PKrTWvPPNLv794z6GRLXitQe7mR9SQniwPcfPMTQ+ietr+LEoNppWDSWkPJ1XBZWnbXuttebv3+xm+g/7iOnTiikDunHddRJSQlTX3hPniYlPwu86xaLYKEIb2V89RngWrwoqT9v2+t3Ve/jgh70M7t2SNwZKSAnhjP0nzzMkPhGARbHRtG0s9ya9hcfeo7LHnbvrOuvdb3fz/nd7eDyyJW8+FCYhVRFHloO6mo9OSfmag6cuEBOfiK1Is2R8NO2beFBIlTe4BFX/2ffygSGvCipP2fb6vdW7+ed3e3g0IoS3BrkopCr64bciI+q116w+OiXlSw5nXSQmPpH8wiKWjL+JDk3rml2SY+y9YKoooHxwYMirggqsv+31+9/t4b3Ve3gkIoS3H+7uujMpR3/4zVbe2Y00rKhA+unikMotsLFoXDQdm3lYSIkq8ap7VFb3r+/28I9vd/NwLxeHlBA+IDM7lyEzEzmXV8DCsVF0aVHP7JKEi3jdGZVVTf9hL3//djeDwoP52yPd8ZOQEqLajubkMiQ+keyLBSSMi6JbsIWvFginyRmVG0z/YS/TVu3iofBgpj3aQ0JKCCccP5vHkPgkss7nM39MH7qH1De7JOFickblYv/+cR/TVu1iYM8WvHN1SLlyIs3RzQXt3QtyZK0xe4yYRHJ2KSerrOcmDHHiXB4x8YmcOJvH/LF9CG91g9kl2edsb8vP7RUkqFxoxk/7ePu/OxnQswV/f6zntWdSrpxIq+gYVw8hlDew4MhaY/aOawRnA1tG0L3GqfOXGBKfxLGcPOaN6UNE6wZml1Q+Z3tbfm6vIJf+XCRuzT6mfr2TB3q04O9yuU8Ip2Sdv8TQ+CQyzlxk9qje9A61cEgJwzkVVEqpR5VSaUqpIqVUpFFFebr4NTT1FREAAA60SURBVPt586ud/K57c959rAc1/OT1gKga6alrnbmQz9CZSRzMusDskb2JbtvQ7JKEmzn7G3QbMAhYY0AtXmHm//bzxlc7uL97c957vKeElHCU9FQZ2ReLQ2r/qQvMHBnJze0bmV2SMIFT96i01jsAWe27xKyfDzDlyx3cH9acf0pIiWqQnvpVTm4Bw2etZ++J88SNiODWDo3NLkmYRIYpDDL75wO8/sV27u3WjPcGVzGkjJjscXS6D5xbP0ymkSxDKTUeGA/QqlUrk6sx1tm8AkbMXs/OY2f5z/AIbu/YxOyShIkqDSql1GqgmZ1PTdJaf1rVb+TNTTV37QFe+2I7/bs24/2YcPyreiZlxGSPEdN99p5bHplGcppRPaW1jgPiACIjI7VB5Znu/KVCRs1eT1pmDv8eFsGdnZqaXZIwWaVBpbW+24hv5K1NNe+Xg7zy+Xb6dW3Kv4Y4EFLCZxnVU97owqVCRs9Zz+aMHKYP6cU9XSSkhIynO2X+uoO8/Fkav+3SlH/F9JKQEsIJF/MLGT03mQ2Hs3l/cDj9u9k76RS+yNnx9IeUUhnATcCXSqlVxpRlfQsSD/HSp2nc06UpHwzpRc0aElLCeb7aU7n5NsbOTSHl4Gn+8VgP7u/e3OyShIU4O/X3CfCJQbV4jIWJh/h/K7dxd+cmTJeQEgbyxZ7KK7AROz+FxANZ/OOxHgzoadFtehxZFkmGjgwlU38OWpR0mMkrt3FXpyZMH+pASFlhp1krNI8V/h2EZeQV2HhiQSpr951i2iM9eCg8xOySyufIskjys2woCSoHLF5/mL9+spU7OzXhw2G9uL6GX9W/2Ao7zVqheazw7y
As4VKhjT8kbOCn3Sd5++EwHomwcEgJU8k1qypasv4wL67Yyh0dG/NvR0NKCHGF/MIinkrYyPc7T/DmQ2E83tu73rIijCVBVQVLk9OZuGIrt3dszL+HRUhICeGEAlsRTy/ewOodx3l9QFeGRElIiYpJUFViaUo6f1mxhd/c2JgZwyII8JeQEqK6Cm1FPLtkI6vSjvPyA10YflOo2SUJDyD3qCqwLCWdvyzfwi3tGxE33E5IyWCAEFVWaCviuaWb+WrrMSbf35nRfduYXZJ9VV1irJRM8rmcBFU5Pk7N4IWSkIofEWn/TMqRwQBXTdxZYZLPEZ5WrzCErUjzf8s28/nmI7x4byfG3drW7JKqxxWbg4pKSVDZsWJDBhM+3kzfdhWElKNcdYblaWdunlavcFpRkeaFj7ewctMRJvTryBO3tTO7JOFh5B7VVT7ZmMHzyzZzc7uGxoWUED6qqEjz4oqtLN+QwZ/vuZGn7mhvdknCA0lQlbFyYybPL93MTW0bMnNEbwJrSkgJUV1FRZpJK7fxUUo6z9zVgWfu6mB2ScJDSVCV+HRTJn9euomoNg2ZNVJCSghnaK15+bM0Fq8/zFN3tOO5uyWkRPX5xD2qlRszmbZqF0eyc2lRP5AJ/ToyMPzX9cQ+23yE5z7aRJ82DZg1KrLqIWWFwQBPmzz0tHqFw7TWvPr5dhYkHuKJ37Tl/37b0bkdi+Vnxud5fVCt3JjJiyu2kltgAyAzO5cXV2wFYGB4MJ9vPsKflmwkMrQBs0f1plZNB/5JrNAknrYkkafVKxyitWbKlzuY+8tBxt7Shon3dnIupMD9PzMy2Wc5Xh9U01btuhxSpXILbExbtYsafoo/fbSJyNYNmONoSAkhrqC1ZurXO5n18wFG3RzK5Ps7Ox9SQuAD96iOZOfafTwzO5dnl2yiV6v6zBndm9rXS0gJUV1aa6at2sV/1uxnWHQrXn6gi4SUMIzXB1WL+oHlfi68ZX3mjO4jISWEk95dvYcPf9xHTJ9WvPZgNwkpYSivD6oJ/ToSaOe9UG0a1mbumD7UkZASwin/XL2H97/bw2ORIbwxsBvXXSchJYzl1G9ppdQ04AEgH9gHjNZaZxtRmFFKp/umrdpFZsllwNCGtfjs6b7eEVJWmDwsT3nTWvZYoV4L8ISeKmv6D3t5d/VuHu4VwtRB3V0TUlb+GTeCTDVWytnf1N8CL2qtC5VSbwMvAn9xvixjDQwPplZNP/6QsIGwkCDmj+lD3QB/s8syhpV/kCsKKZmsKo9H9BTAjJ/2MW3VLgb2bMHfHnFRSIG1f8aNIJOwlXLq0p/W+hutdWHJh4mAJbfo/Hb7cZ5atIFuwUHM86aQEl7HU3pq5v/2M/XrnTzQowXvPNoDP7ncJ1zIyHtUY4Cvy/ukUmq8UipFKZVy8uRJA79txVZvP84fElLp0iKI+WP7UE9CSngOS/bUnLUHmPLlDu4Pa867j/Wghp/X3+oWJqv00p9SajXQzM6nJmmtPy15ziSgEEgo7zha6zggDiAyMlJXq1oHfbfjOE8mpNKleT3mj5GQEtbgyT21YN1BXv18O/26NuW9wT0lpIRbVBpUWuu7K/q8UmoU8DvgLq21W5qlKr7feZwnF26gc/N6zB8bRVCghJSwBk/tqUVJh/l/n6Zxd+em/CumF/4SUsJNnPpJU0r1B14AHtRaXzSmJOf9sPMEv1+wgY7N6rJgjISUacqbyvKWaS0XsGpPLU1O56+fbOXOTk2YPjScmjUkpAwjfVIpZ6f+PgCuB74teYNfotb6905X5YQfd53giQWp3NisDgvHRhFUS0LKNN4+reUaluupj1Mz+MuKLdx2Y2M+HNqL62vIzgKGkj6plFNBpbW21C5oP+0+yfgFqXRoKiElPJPVeuqTjb/udv2f4RGykagwhdecv6/ZfZLY+Sm0b1yHhHFR1K9V0+yShPBon20+wvNLNxPdRna7FubyiqD6357ikGonISWEIb7ccpTnPtpEZK
iDe7QJ4QIev4bQz3tOMW5eCm0a1SZhXBQ31DY5pGQ5FOHh/rvtKM8s2Vi8aLNsfyMswKPPqNbuPcXYecm0aVSbRbHRNDA7pECWQxEe7Zu0Y/xx0UZ6hAQxd4zsLCCswWOD6pcyIZUwLsoaISWEB/tuR/FSY12Dg2RnAWEpHhlUv+w7xZh5ybRuUBxSDetcb3ZJQni0H3ed4MmFG+jUTFZxEdbjcUG1bl8WY+Ym06pBLRJiJaSEcNb/9vz6to4FY/vIG+SF5XhUUCXuLw6pljfUYlFsNI0kpIRwyi97i4eR2jaqzcKxMjErrMljgippfxaj5yQTfEOgtUNKlkMRHiJxfxZj5iUT2tAiE7NClMMj7pauP3Ca0XOTaVE/gEWxUTSua9GQAhlBFx4h+eDpy1cn5BK6sDrLn1ElHzzNqDnraRYUwOLYaJrUDTC7JCE8WuqhM4yaXdxTCbFR1r06IUQJSwdVysHTxQ1VL4AlsdE0qSchJYQzNh4+w8jZ62lST174Cc9h2aBKPXSakbPX07ReAIvHS0gJ4awtGdmMmL2ehnVqsjg2mqbSU8JDWDKoUg+dYeTs5OJXfeOloYRw1rbMHIbNTCIo0J9FsdE0C5KeEp7DckG1oeTSRCN51SeEIbYfOcuwWUnUDfBncWw0wfUDzS5JCIdYKqi01rz+xfbiSxPj5VWfEEZ46+sdBPr7sTg2mpYNapldjhAOc2o8XSn1OjAAKAJOAKO01kecOB7/GR5BoU3TPEhe9QlhhH/FhHM2t5BWDSWkhGdy9oxqmta6u9a6J/AF8JKzBTWpG0ALuTQhfJRS6nWl1Bal1Cal1DdKqRbOHrN+rZoSUsKjORVUWuuzZT6sDWjnyhHC5xn+4k8IT+f0yhRKqTeAEUAOcIfTFQnhw+TFnxDXqvSMSim1Wim1zc6fAQBa60la65ZAAvDHCo4zXimVopRKOXnypHF/AyG8jFLqDaVUOjCUCs6opKeEr1BaG/OCTSnVCvhKa92tsudGRkbqlJQUQ76vEGZQSqVqrSOr+bWrgWZ2PjVJa/1pmee9CARorV+u7JjSU8IblNdXzk79ddBa7yn5cACw05njCeELtNZ3V/GpCcBXQKVBJYQ3c+qMSim1HOhI8Xj6IeD3WutKlw9XSp0seX55GgGnql2Y8axWD1ivJqvVA66tqbXWurHRBy374k8p9TRwm9b6kSp8naf1FFivJqmncq6uyW5fGXbpz0hKqZTqXlZxBavVA9aryWr1gDVrqkx1X/xV4biW+7ewWk1ST+XMqskj9qMSwldorR82uwYhrMZSSygJIYQQV7NqUMWZXcBVrFYPWK8mq9UD1qzJLFb8t7BaTVJP5UypyZL3qIQQQohSVj2jEkIIIQALB5VSappSamfJAp2fKKXqm1zPo0qpNKVUkVLKtEkcpVR/pdQupdRepdREs+ooU89spdQJpdQ2s2sBUEq1VEr9oJTaXvL/61mza7IK6akKa7FMX1mtp8D8vrJsUAHfAt201t2B3cCLJtezDRgErDGrAKWUHzAduBfoAsQopbqYVU+JuUB/k2soqxB4XmvdBYgGnrLAv5FVSE/ZYcG+mou1egpM7ivLBpXW+hutdWHJh4lAiMn17NBa7zKzBqAPsFdrvV9rnQ8soXhFENNordcAp82soSyt9VGt9YaS/z4H7ACCza3KGqSnymWpvrJaT4H5fWXZoLrKGOBrs4uwgGAgvczHGcgv4XIppUKBcCDJ3EosSXrqV9JXDjCjr0x9w29VFudUSk2i+LQzwQr1CM+glKoDLAf+dNXWGV5Nekq4kll9ZWpQVbY4p1JqFPA74C7thjl6BxYLNUsm0LLMxyElj4kylFL+FDdTgtZ6hdn1uJP0VLVIX1WBmX1l2Ut/Sqn+wAvAg1rri2bXYxHJQAelVBulVE1gMPCZyTVZilJKAbOAHVrrf5hdj5VIT5VL+qoSZveVZYMK+ACoC3yrlNqklJphZjFKqYeUUhnATcCXSq
lV7q6h5Eb4H4FVFN/MXKq1TnN3HWUppRYD64COSqkMpdRYM+sB+gLDgTtLfm42KaXuM7kmq5CessNqfWXBngKT+0pWphBCCGFpVj6jEkIIISSohBBCWJsElRBCCEuToBJCCGFpElRCCCEsTYJKCCGEpUlQCSGEsDQJKiGEEJb2/wF5x+N/Q0V3uwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 504x216 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XaFMa20IAsMk",
        "colab_type": "text"
      },
      "source": [
        "## High-level implementation using the nn.Module API"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k8Eqg0MKAsMl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class LogisticRegression3(torch.nn.Module):\n",
        "    \"\"\"Logistic regression as an nn.Module: one linear layer + sigmoid.\"\"\"\n",
        "\n",
        "    def __init__(self, num_features):\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(num_features, 1)\n",
        "        # Start from all-zero parameters, matching the zero\n",
        "        # initialization used in the manual implementation.\n",
        "        # nn.init.zeros_ fills each tensor in place without recording\n",
        "        # the operation in autograd; the trailing underscore marks an\n",
        "        # in-place operation in PyTorch.\n",
        "        for param in (self.linear.weight, self.linear.bias):\n",
        "            torch.nn.init.zeros_(param)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return sigmoid(linear(x)): probabilities with a trailing dim of 1.\"\"\"\n",
        "        return torch.sigmoid(self.linear(x))\n",
        "\n",
        "model = LogisticRegression3(num_features=2).to(device)"
      ],
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ATsNJ2CxAsMp",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 165
        },
        "outputId": "3a874eff-d1b8-4386-e205-73d5c07438be"
      },
      "source": [
        "import hiddenlayer as hl\n",
        "\n",
        "# Trace the model with a dummy batch to render its graph. The dummy input\n",
        "# must be on the same device as the model, or tracing fails when `device`\n",
        "# is a GPU. Shape [75, 2] matches the training set used below.\n",
        "hl.build_graph(model, torch.zeros([75, 2], device=device))"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<hiddenlayer.graph.Graph at 0x7f79b88078d0>"
            ],
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"310pt\" height=\"108pt\"\n viewBox=\"0.00 0.00 310.00 108.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 72)\">\n<title>%3</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-72,36 -72,-72 238,-72 238,36 -72,36\"/>\n<!-- /outputs/3 -->\n<g id=\"node1\" class=\"node\">\n<title>/outputs/3</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-36 0,-36 0,0 54,0 54,-36\"/>\n<text text-anchor=\"start\" x=\"14\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n</g>\n<!-- /outputs/4 -->\n<g id=\"node2\" class=\"node\">\n<title>/outputs/4</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"166,-36 112,-36 112,0 166,0 166,-36\"/>\n<text text-anchor=\"start\" x=\"122\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Sigmoid</text>\n</g>\n<!-- /outputs/3&#45;&gt;/outputs/4 -->\n<g id=\"edge1\" class=\"edge\">\n<title>/outputs/3&#45;&gt;/outputs/4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M54.1121,-18C68.3143,-18 85.9303,-18 101.4272,-18\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"101.7494,-21.5001 111.7494,-18 101.7494,-14.5001 101.7494,-21.5001\"/>\n<text text-anchor=\"middle\" x=\"83\" y=\"-21\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">75x1</text>\n</g>\n</g>\n</svg>\n"
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IaOjEACzAsMs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##### Define cost function and set up optimizer #####\n",
        "# reduction='sum' (the replacement for the deprecated size_average=False)\n",
        "# matches the manual approach, where we did not normalize the cost\n",
        "# by the batch size\n",
        "cost_fn = torch.nn.BCELoss(reduction='sum')\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)"
      ],
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "tags": [],
        "id": "AFm9tlSlAsMw",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 311
        },
        "outputId": "94f88c5e-2bf0-4744-dc52-92945c3e1e17"
      },
      "source": [
        "def comp_accuracy(label_var, pred_probas):\n",
        "    \"\"\"Fraction of samples whose prediction, thresholded at 0.5, equals the label.\n",
        "\n",
        "    Both tensors are flattened with `.view(-1)`, so column vectors and 1-D\n",
        "    tensors are accepted alike. Returns a 0-dim float tensor.\n",
        "    \"\"\"\n",
        "    # custom_where (defined earlier) presumably acts like where(cond, 1, 0)\n",
        "    pred_labels = custom_where((pred_probas > 0.5).float(), 1, 0).view(-1)\n",
        "    acc = torch.sum(pred_labels == label_var.view(-1)).float() / label_var.size(0)\n",
        "    return acc\n",
        "\n",
        "\n",
        "num_epochs = 10\n",
        "\n",
        "X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)\n",
        "y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device).view(-1, 1)\n",
        "print(X_train_tensor.shape)\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    #### Compute outputs ####\n",
        "    out = model(X_train_tensor)\n",
        "\n",
        "    #### Compute gradients ####\n",
        "    cost = cost_fn(out, y_train_tensor)\n",
        "    optimizer.zero_grad()\n",
        "    cost.backward()\n",
        "\n",
        "    #### Update weights ####\n",
        "    optimizer.step()\n",
        "\n",
        "    #### Logging ####\n",
        "    # Re-evaluate after the update; no_grad avoids building an autograd\n",
        "    # graph that would never be back-propagated.\n",
        "    with torch.no_grad():\n",
        "        pred_probas = model(X_train_tensor)\n",
        "        acc = comp_accuracy(y_train_tensor, pred_probas)\n",
        "        logged_cost = cost_fn(pred_probas, y_train_tensor)\n",
        "    print('Epoch: %03d' % (epoch + 1), end=\"\")\n",
        "    print(' | Train ACC: %.3f' % acc, end=\"\")\n",
        "    print(' | Cost: %.3f' % logged_cost)\n",
        "\n",
        "\n",
        "print('\\nModel parameters:')\n",
        "print('  Weights: %s' % model.linear.weight)\n",
        "print('  Bias: %s' % model.linear.bias)"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch.Size([75, 2])\n",
            "Epoch: 001 | Train ACC: 0.987 | Cost: 5.581\n",
            "Epoch: 002 | Train ACC: 0.987 | Cost: 4.882\n",
            "Epoch: 003 | Train ACC: 1.000 | Cost: 4.381\n",
            "Epoch: 004 | Train ACC: 1.000 | Cost: 3.998\n",
            "Epoch: 005 | Train ACC: 1.000 | Cost: 3.693\n",
            "Epoch: 006 | Train ACC: 1.000 | Cost: 3.443\n",
            "Epoch: 007 | Train ACC: 1.000 | Cost: 3.232\n",
            "Epoch: 008 | Train ACC: 1.000 | Cost: 3.052\n",
            "Epoch: 009 | Train ACC: 1.000 | Cost: 2.896\n",
            "Epoch: 010 | Train ACC: 1.000 | Cost: 2.758\n",
            "\n",
            "Model parameters:\n",
            "  Weights: Parameter containing:\n",
            "tensor([[ 4.2267, -2.9613]], requires_grad=True)\n",
            "  Bias: Parameter containing:\n",
            "tensor([0.0994], requires_grad=True)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1tfXeJUTAsM3",
        "colab_type": "text"
      },
      "source": [
        "#### Evaluating the Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "SbAwApiGAsM5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "2c6ab791-2856-47ce-9e79-cdfa46a04505"
      },
      "source": [
        "X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)\n",
        "y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)\n",
        "\n",
        "# Inference only: disable gradient tracking.\n",
        "with torch.no_grad():\n",
        "    pred_probas = model(X_test_tensor)\n",
        "    test_acc = comp_accuracy(y_test_tensor, pred_probas)\n",
        "\n",
        "print('Test set accuracy: %.2f%%' % (test_acc*100))"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test set accuracy: 100.00%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oA7cdH92AsNA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 211
        },
        "outputId": "c2c2355b-d0a6-4df4-e4b7-7df6bc4706b9"
      },
      "source": [
        "##########################\n",
        "### 2D Decision Boundary\n",
        "##########################\n",
        "\n",
        "# Use the parameters of the nn.Module `model` trained above -- not the\n",
        "# manually implemented `logr` -- converted to plain Python floats so the\n",
        "# plot works regardless of which device `model` lives on.\n",
        "w0, w1 = model.linear.weight.detach().cpu().view(-1).tolist()\n",
        "b0 = model.linear.bias.detach().cpu().item()\n",
        "\n",
        "# Boundary where the logit is 0 (probability 0.5):\n",
        "# w0*x + w1*y + b0 = 0  =>  y = (-(w0*x) - b0) / w1\n",
        "x_min = -2\n",
        "y_min = (-(w0 * x_min) - b0) / w1\n",
        "\n",
        "x_max = 2\n",
        "y_max = (-(w0 * x_max) - b0) / w1\n",
        "\n",
        "\n",
        "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n",
        "ax[0].plot([x_min, x_max], [y_min, y_max])\n",
        "ax[1].plot([x_min, x_max], [y_min, y_max])\n",
        "\n",
        "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n",
        "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n",
        "\n",
        "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n",
        "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n",
        "\n",
        "ax[0].set_title('Training set')\n",
        "ax[1].set_title('Test set')\n",
        "ax[1].legend(loc='upper left')\n",
        "plt.show()"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaoAAADCCAYAAAAYX4Z1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deVyVdfr/8dcnxMANc1dQccsVFUGhrGkfraY028R9w6Zpqmn62tjor9XKxmZqmmwccFfUNM32sWxzMkHAHfcdcEfBDQQOn98fgKEelsO5z7nvc871fDx8POJwuLk0Lt7nvu/rfD5Ka40QQghhVdeZXYAQQghREQkqIYQQliZBJYQQwtIkqIQQQliaBJUQQghLk6ASQghhaTXM+KaNGjXSoaGhZnxrIQyRmpp6Smvd2Ow6SklPCW9QXl+ZElShoaGkpKSY8a2FMIRS6pALjx0ArAGup7hHP9Zav1zR10hPCW9QXl85HVTVaSohRIUuAXdqrc8rpfyBn5VSX2utE80uTAgzGHFGJU0lhIF08XIx50s+9C/5I0vICJ/ldFBJU3mWlRszmbZqF0eyc2lRP5AJ/ToyMDzY7LLEVZRSfkAq0B6YrrVOsvOc8cB4gFatWrm3QCHcyJB7VFVpKmG+lRszeXHFVnILbABkZufy4oqtABJWFqO1tgE9lVL1gU+UUt201tuuek4cEAcQGRl5zYvDgoICMjIyyMvLc0vNniIgIICQkBD8/f3NLkVUkSFBVZWmkld/5pu2atflkCqVW2Bj2qpdElQWpbXOVkr9APQHtlX2/LIyMjKoW7cuoaGhKKVcU6CH0VqTlZVFRkYGbdq0MbscjzLjp33Uub4Gw6Jbu/17G/o+Kq11NlDaVFd/Lk5rHam1jmzc2DJTvT7lSHauQ497ssNZF8nJLTC7jGpRSjUuedGHUioQuAfY6ehx8vLyaNiwoYRUGUopGjZsKGeZDopfs5+pX+8k+eBpzNhxw+mgMqqphOu1qB/o0OOe6lDWBR6PW8ezSzaaXUp1NQd+UEptAZKBb7XWX1TnQBJS15J/E8fM/vkAb3y1g/vDmvP3R3uY8u9nxBmVYU0lXGtCv44E+vtd8Vigvx8T+nU0qSLjHc66SExcInkFNv7Sv5PZ5VSL1nqL1jpca91da91Na/2a2TUZ6ZVXXuGdd95xybFTU1MJCwujffv2PPPMM6a8+vcm89cd5LUvttO/azPeG9yTGn7mLGbk9Hf19qbyJgPDg3lrUBjB9QNRQHD9QN4aFOY196cOZ11kcNw6cgtsJIyLpnPzemaXJNzsySefJD4+nj179rBnzx7++9//ml2Sx0pIOsRLn6ZxT5emvB8Tjr9JIQUmrUwhzDMwPNhrgqms9NMXiYlP5GKBjYRxUXRpISHlKFe8dWH+/Pm88847KKXo3r07CxYsuOLz8fHxxMXFkZ+fT/v27VmwYAG1atVi2bJlvPrqq/j5+REUFMSaNWtIS0tj9OjR5OfnU1RUxPLly+nQocPlYx09epSzZ88SHR0NwIgRI1i5ciX33nuvU38HX/RR8mEmfbKNuzo1YfqQXtSsYe6ysBJUwuOln77I4LhEzl8qJGFcFF1bBJldksdxxVsX0tLSmDJlCr/88guNGjXi9OnT1zxn0KBBxMbGAjB58mRmzZrF008/zWuvvcaqVasIDg4mOzsbgBkzZvDss88ydOhQ8vPzsdmunGDNzMwkJCTk8schISFkZmZWq3ZftiwlnYkrtnJ7x8Z8OMz8kAJZPV14uIwzxWdSpSHVLVhCqjoqeutCdX3//fc8+uijNGrUCIAGDRpc85xt27Zx6623EhYWRkJCAmlpaQD07duXUaNGER8ffzmQbrrpJt58803efvttDh06RGCgdw0BWcEnGzN4YfkWbmnfiBnDIri+hl/lX+QGElTCY2Vm5zI4LpGzuQUSUk4y660Lo0aN4o
MPPmDr1q28/PLLl8fGZ8yYwZQpU0hPTyciIoKsrCyGDBnCZ599RmBgIPfddx/ff//9FccKDg4mIyPj8scZGRkEB1fvbHDlxkz6Tv2eNhO/pO/U71m50fvPzD7dlMnzSzdzU9uGxI+IJMDfGiEFElTCQxWH1LqSkIqWkHKSK966cOedd7Js2TKysrIA7F76O3fuHM2bN6egoICEhITLj+/bt4+oqChee+01GjduTHp6Ovv376dt27Y888wzDBgwgC1btlxxrObNm1OvXj0SExPRWjN//nwGDBjgcN2ll0Ezs3PR/HoZ1JvD6ostR3juo030Dm3AzJHWCimQoBIe6Eh2LjFxiWRfLGDhuCjCQiSknOWKty507dqVSZMmcdttt9GjRw/+/Oc/X/Oc119/naioKPr27UunTr++nWDChAmEhYXRrVs3br75Znr06MHSpUvp1q0bPXv2ZNu2bYwYMeKa43344YeMGzeO9u3b065du2oNUrjiMqiVfb31KM8u2URE6xuYPao3tWpab3RBmfE+g8jISC1754jqOJpTfLnv9Pl8Fo6LokfL+qbUoZRK1VpHmvLN7bDXUzt27KBz585VPoYvLVhc0b9Nm4lf2l1VWwEHpt7v0rrc7Zu0Y/whYQPdQ4KYPzaKOtebG1Ll9ZX1olOIchzLybscUgtMDClv5a1vXXBUi/qBZNq5N+dtK7h8t+M4Ty3aQLfgIOaN6WN6SFVELv0Jj1AcUuvIOp/P/LF96CkhJVzEF1Zw+WHXCZ5cuIHOzesxb0wf6gZYeyV560aoECWOn80jJj6RUyUhFd7qBrNLEhZ25mI+x3PyyLcVUdPvOpoGBXBDrZpV/vrSs0pvvQy6ZvdJnliQSoemdVgwJoqgQGuHFEhQCYs7fjaPmLhETpzNY/7YKHpJSIkKnLmYT+aZXIpK7r3n24rIPFN8Gc/RsPKWYCpr7d5TxM5PoV3jOiwcG0VQLeuHFMilP2FhJ0rOpI6fzWP+2D5EtJaQEhU7npN3OaRKFWnN8RzZ1mPdvizGzkumTaPaJIyL4obaVQ9us8kZlbCkE+eKQ+pYTh7zx/QhovW1qxoIcbV8W5FDj/uK9QdOM2ZuMi1vqMXCcVE08KCQAjmjEhZ08twlYuISOZqTx7wxfYgMlZDyJq7c5uPDaVP4bZ+uRHcMueLxmiau/G221EOnGT1nPS3qB7AoNppGda43uySH+e7/PWFJJ89dIia+OKTmju5Dbwkp4YBBDw1g8RdXLq10nVI0DQowqSJzbTx8hpGzk2laL4DFsdE0rut5IQVy6U9YyKnzlxgSn0jmmVzmju5NnzYSUm7zZjDkn7/28Zp14K/VXzrIndt8APz29ls5czGf0j1oqzP15y02p2czYtZ6GtapyaLYaJrU89ywlqASgPmrEpSGVMaZXOaM7k1U24Zu+95Wo5RqCcwHmgIaiNNa/9Ol39ReSFX0eBW4e5uPUjfUqolS0D3Ed99rty0zh+Gzkqhf25/FsdE08/AzSqcv/SmlWiqlflBKbVdKpSmlnjWiMOE+Zi/CmXX+EkPjkzh8+iKzRkUS7cMhVaIQeF5r3QWIBp5SSnUxuSaHyTYf5kg7ksPQmUnUDSgOKW9YUcOIe1Re0VS+zMxFOLPOX2LozCQOnb7A7JG9ubldI5d/T6vTWh/VWm8o+e9zwA7A+97Ug7HbfAjYeewsw2YmUbumH4tjowm5oZbZJRnC6aDypabyVmbtRXT6Qj5DZyZx4NQFZo3szc3tJaSuppQKBcKBJDufG6+USlFKpZw8edLdpVXK3dt8+Lrdx88xND6J62v4sSg2mlYNvSOkwOCpP09uKl/mir2IKnPmqpDqKyF1DaVUHWA58Cet9dmrP6+1jtNaR2qtIxs3buz+AithxjYfL7zwAiEhIVy8eJGQkBBeeeUVV/4VLWPviXMMiU/E7zrFotgoQhvVNrskQxm2zUdJU/0EvKG1XlHRc2WbD2spvUdV9vJfoL8fbw0Kc8
lARWlI7T15nlkjI7m1g/V+yVbG1dt8KKX8gS+AVVrrf1T2fKe3+XDR1J9VOboFipXtO3mewXGJaA1LxkfTvkkds0uqNpdu81HSVMuBhMpCSlTf5JVbWZyUjk1r/JQiJqolUwaGOX1cdy7CmX0xn2GzikNq5gjPDClXU0opYBawoyohZQgvDCNfcPDUBYbEJ1JUpD0+pCridFCZ0lQ+aPLKrSxMPHz5Y5vWlz82KqxcPY5eGlJ7TpwnfkQkv7lRQqocfYHhwFal1KaSx/6qtf7KxJqExRzOukhMfCIFNs3i2Gg6NK1rdkkuY8Q9qtKmulMptankz30GHFeUsTgp3aHHrSbnYgHDZ61n97HzxA2P4DYJqXJprX/WWiutdXetdc+SPxJS4rL008UhlVtgY+HYKDo2896QAgPOqLTWP8PlN4ILF7GVcy+xvMetJCe3gOGzk9h17Bz/GR7B7R2bmF2Sz9BaU3zRQ5Qy6r68WTKzc4mJT+RcXgGLYqPp0qKe2SW5nKz15yH8yvllU97jVpGTW8CIWUnsPHqOGcN7cUcnCSl3CQgIICsry+N/MRtJa01WVhYBAZ65UsPRnFxi4hLJyS1g4bgougUHmV2SW8gSSh4iJqrlFfeoyj5uVWfzChgxez3bj55lxrAI7uzU1OySfEpISAgZGRl489tBsi/mc+GSDU3xZZ3a1/tRv5J1/QICAggJCanwOVZ0/GweQ+KTOHMhnwXjonxqiSgJKjdyZj290oEJZ6f+3LWm39m8AkbMWs/2Izn8e2gEd3WWkHI3f39/2rRpY3YZLlM8YJRxzePDolsZMmBkJaX7s5XudN2zpe+EFEhQuc3V71UqXU8PcCisnGlAI2qoinN5BYycvZ60Izl8ODSCu7tISAnjVTRg5E1BVbxgcxLHSvZn88WdruUelZuYuZ6eO2soDamtGTlMH9KLeySkhIt48oBRVZUu2Jx5Jpc5o3r77P5sElRuYtZ6eu6s4fylQkbNSWZLRg4fDOnFb7s2M+S4QtjjqQNGVVW6gsvBrAvMGhnp01vfSFC5iRnr6bmzhvOXChk1ez2b07P5YEg4/btJSAnXKm+QyJEBo5UbM+k79XvaTPySvlO/d9vWNpXJvlgcUvtPXWDmyEifX7BZgspNJvTriP91V77S879OMaFfR7vPd0UDTejXkUB/vyseC/T3K7eGqjp/qZDRc9azMT2bf8WE079bc6eOJ0RVTBkYxrDoVpfPoPyUcmiQwux92MqTk1v85vi9J4rfHC/LjMkwhXtdfUWinCsUrhp6cMWafhcuFTJmTjIbDheH1L1hElLCfZwZMKronq07d7cuq/QtHTuPnZU3x5chQeUm01btosB25U3eApu22xSubCAj1/S7cKmQ0XOTST18hvcHh3OfhJTwIFa4b1zW5WnZzBz+Le87vIJc+nMTR5rCag1kz8X8QsbMTSb10Bnee7wn93eXkBKexQr3jUtduFTI6DKDSDIteyUJKjdxpCms1ED2lIZU8sHTvPt4Tx7o0cLskoRwmKvu2TrqYn7xlYmN6dm8P1gGkeyRoHITR5rCKg1kT26+jbFzU1h/oDikHpSQEh5qYHgwbw0KI7h+IAoIrh/oss1Cy1PaTyklL/rkyoR9Pn2PylXLCVV03Kp8P3duZOiI3HwbY+clk3Qgi3cf78mAnubWI4Sz3LEPW3nyCmzEzk8h8UAW7z4mL/oq4rNB5arJusqOW9Vjm9lA9uQV2Bg3P5l1+7P4x2M9JKSEcEJegY3xC1JZu+8U0x7pYaletyKfvfTnquWErLBUktHyCmyMm5fCL/uy+PujPXgo3PNWnhbCKi4V2vhDwgbW7D7J24O680iE9FNlfDaoXDVZ5wkTe44ovTxR+spvUC9pKndQSs1WSp1QSm0zuxZhnPzCIp5K2Mj3O0/w5kNhPNbbutv0WIkhQeWJTeWqyTqrT+w5ovTyxM97T/G3h+WVn5vNBfqbXYQwToGtiKcXb2D1juO8Pr
AbQ6JamV2SxzDqjGouHtZUrpqss/LEniPyCmw8sSCV/+05ydsPd+fRSHnl505a6zXAabPrsDpXrdVn9HELbUU8u2Qjq9KO88oDXRge3dqQOn2FIcMUWus1SqlQI47lLq6arBsYHsyylMOs3ffr75herYIYGB5sdxrQFTU4K6/Axu8XpvLT7pP87eHuPCYhZUlKqfHAeIBWrXzv1blZA1GOKrQV8dzSzXy19RiT7+/MqL7eu5mlqyht0N4tJUH1hda6WzmfL9tUEYcOHTLk+1pN8a6j124Z37ddAzYczrli0MLfT4GGgqJf/x8E+vu5/b0cZV0qtPH7Ban8sOskUweFMbiP7/0CrAqlVKrWOtLF3yOUCnqqrMjISJ2SkuLKciyn79TvybRz7ze4fiBrJ95piePaijTPL93Eyk1HePHeTjxxW7tq1+ULyusrtw1TaK3jtNaRWuvIxo29dzXg8nYdXbvv9DXTgAU2fUVIgbkTgpcKbTy5cAM/7DrJWxJSwuKsPhBlK9JM+HgzKzcdYUK/jhJSTvDZqT9XMWJ3UTMmBC8V2vjDwg2Xp5FiJKSExVl5IKqoSPPiii2s2JDJn++5kafuaO9UTb5OgspgRuwu6u4JweKR2Q18t/MEU2QayRKUUouBdUBHpVSGUmqs2TVZjVUHooqKNJNWbmVpSgbP3NWBZ+7q4FQ9wqBhipKmuh1opJTKAF7WWs8y4tieJiaqpdP3qNw5IZhfWMRTizawescJXh/YjWEyjWQJWusYs2uwuoHhwaQcOs3ipHRsWuOnFA9HOLaii7PLnV1Na81Ln21j8fp0nrqjHc/dLSFlBKOm/jyyqRxZ62/yyq1XNERMVEu7G7ZNGRhG0v4s9py4cPmxDk1qkxB7k93v50ijGb02YYGtiD8u2sC324/z2oCuVRuZfTMY8s9f+3jNOvBXa2zjLXzDyo2ZLE/NvHy53aY1y1MziWzdoEp9YdRyZ6W01rz6+XYWJh7midva8n+/7Ygy4AqL8OFLf45sQ106yVe2IRYmHmbyyq12n1s2pAD2nLjA5JVbGRgezNqJd3Jg6v2Xp4fsNZq9GozeNrs0pL7ZfpxXH+zKiJtCq/aF9kKqoseFcBFnlyszcrkzrTWvf7GDub8cZNwtbZjYv5OElIF8Nqgc+SEtb5LP3uOOPNeRGoxsqgJbEc8s/vXNhyNvDnX4GEKYzdnpPKOm+7TWvPX1TmavPcCom0OZdH9nCSmD+WxQOfJDWt4kn73HHXmuGbv+FpS8Q/7rbcd46Xdd5M2HwmM5O51nxHSf1pppq3YRt2Y/w6Nb8/IDXSSkXMBng8qRH9LyJvnsPe7Ic92962+hrYg/LdnEV1uP8f9+14Uxt0hICc/l7HSeEVOD767ew4c/7mNIVCtefbCrhJSL+GxQOfJDGhNlfwkhe4878lx37vpbaCvi2Y828eXWo0y+vzNjJaSEh3N2h15nv/6fq/fw/nd7eDyyJVMGdOO660wMqTeD4ZWga/+86R37XPnsxomOjKCWTvdVdervwMnzV6z117ddA6YMDLM7tffWoDCX7/pbutbYl1uOMum+zoy7tW3V/pHsqVmn/Kk/IdzM2Q1Gq/v103/Yy7urd/NIRAhvDQozN6TA64ecDFvrzxHevC7Z1SOvUHzm83BEMMtTM6953NXr+hXaivjz0s18tvkIf72vE+N/I8u4GMEda/05wpt7ympm/LSPqV/v5KHwYN55tAd+ZocUFJ89lfu5HPfV4STT1/rzFeVN5y1OSnf7zr+2Is3zy4pDauK9ElJCOGvm//Yz9eudPNCjBdMe6W6NkPIBElQGK28Kr7xpQFet61e6avOnm47wl/6d+L0siCmEU+asPcCUL3dwf1hz3n2sBzX85Nenu/jsPSpXaVE/0O4WAX5K2Q0rV6zrZyvSTFj266rNT94uISXMY/SKKmaYv+4gr36+nX5dm/Le4J6eFVJXXxb0wFVkvC6oHGkKV2xkOKFfRyZ8vJkC26
+h5O+neLx3S7v3qIxe1690a4EVGzOZ0K9j9Vdt9sWlknzx7+xirtrc0J0Skg7x0qdp3N25Kf+K6YW/FUOqvCEnezxwwMKrgsqRprD33AnLNoPicshUu6muPnHSENm6AZGtG7j0laWtSPPCx8VbCzzv7NYCXj5FZJcv/p1drKIVVTwhqD5KPsykT7ZxZ6cmTB8aTs0aFgwpsP9CqqIBCw/jVUHlSFPYe+7VmxhW9PUV1XD1cQqKit+9vnbinS5rzqIizcTlW1i+IYPn7r6Rp2VrAWEBrtrc0B2WpaQzccVWbruxMR8O7cX1Nfwq/yLhEhZ9eVA9RixJ5Mhxna3BKEVFmokrtrAsNYM/3d2BZ2VrAWERrtrc0NU+2ZjBC8u3cEv7RvxneAQB/hJSZvKqoDJiSSJHjutsDUYo3kn0103a/nT3jS75PkJUh6s2N3SlTzdl8vzSzdzUtiFxwyMlpCzAqy79TejX0e6bbctbkujq5/pfp664R1XR1xtRg7OKijR//WQrH6Wk88yd7WWTNmE5zqyoYoYvthzhuY820Tu0ATNHRhJY04NDyohVZBwZMHLhMJJRO/z2B/4J+AEztdZTjTiuoxxpCnu7gz7ep2W5Aw/lTRMOjV93zXJJVV0WyRnF211vY0lyOn+8oz3P3XOjsQtieuJSSc42ioX+zlbpKSM4u8yRu/x321GeXbKJiNY3MHtUb2rV9PDX8UZMqjoyYOTCYSSn/08opfyA6cA9QAaQrJT6TGu93dljV0dVm6Ki3UFLNzUs+1x704TTf9hzzSaJpaF19TGMVFSkmfzpNhavP8xTd7Tj+d8aHFLgmePYzjaKRf7OVuspX/BN2jH+uGgjPUKCmDO6D7Wv9/CQ8jJG3KPqA+zVWu/XWucDS4ABBhzXpYzYtPDqkCpV9gzLaFprXvpsG4uSDvPk7e1ku2vv5JE95am+23GcpxZtoFtwEPPG9KGOhJTlGBFUwUDZ7WszSh6zNFdNCLqS1pqXPk1jYeJhfn9bO17oJyHlparUU0qp8UqpFKVUysmTJ91WnDf5cdcJnly4gc7N6zFvTB/qBvibXZKww21Tf1ZrKldNCLqK1pqXP0tjQeIhnvhNW/7SX0LK12mt47TWkVrryMaNG5tdjsf5356TjF+QSoemdVgwJoqgQAkpqzIiqDKBsrsChpQ8dgWrNZURmxZ2aFLb7rH7tmtgXKEUh9Srn29n/rpDjP9NWybe20lCyrtVqadE9f2y9xTjZv1MW9sBFmYNJuhvjbxus0FDlDdIZO9xR57rICMuxiYDHZRSbShupsHAEAOOW6HypvCqutafoxOCy1IOX3HvqVerIBJib+Kef/x4xb2qDk1q82hkK/pO/d6Qqb/SkJr7y0HG3dKGF6saUo4sn+LsqKm718gr7/t5D1N6ylck7s9izLxkQtUxEmq+yQ3qqp8l7/7Zcowj/evCYSSng0prXaiU+iOwiuJR2tla6zSnK6tAeVN4KYdOX7Hwa2Vr9VV1QnDyyq3XDEis3XeaofHryDiTd8XjB7MuMmHZ5svLKDmzCKfWmte+KA6psbe0YdL9nV1zJuXsqKm718jz8l8kZvSUr1h/4DRj5ibT8oZaJOS8SUN1zuySRBUYco9Ka/2V1vpGrXU7rfUbRhyzIu7enHBxUrrdx9fuO33teoE2fc1af9WpQWvNlC93MGftQUb3DWWyq0JKWJK7e8oXpB46zeg562kWFEBCbBSN1FmzSxJV5JFLKLl7c8LyjusIR2rQWvPGlzuY9fMBRt0cyku/6yIhJYQTNh4+w8jZyTSpF8Di2Gia1A0wuyThAI8MqvKm8PzK+WXu7NReecd1RFVr0Frz1tc7mVkSUi8/ICElhDM2p2czYtZ6GtapyeLYaJrWk5DyNB75zrby1tN7OCLYJZsTxkS1ZGHi4Wse79uuARsO51y5XqCfAn3lliFVrUFrzdSvdxK3Zj8jbmotISWEk7Zl5jB8VhL1a/uzODaaZkFlQs
oqS2ZZeXDJIjwyqCqb2Cu7ft/DEc6vMzZlYNg1x42JasmUgWGG7RKstebt/+7iP2v2Mzy6Na8+2NXckHJ3E7uyAa28PqFwmbQjOQydmUTdgOKQuuaqhlV+sVt5cMkiPDKowP7EXkXr9xkRVqWBVVkdpY9Xldaav63axYyf9jEsuhWvDTA5pMD9TWxUA76S43wtwuPtPHaWYTOTqF3Tj8Wx0YTcUMvskoQTPPIeVXkcWb/PKrTWvPPNLv794z6GRLXitQe7mR9SQniwPcfPMTQ+ietr+LEoNppWDSWkPJ1XBZWnbXuttebv3+xm+g/7iOnTiikDunHddRJSQlTX3hPniYlPwu86xaLYKEIb2V89RngWrwoqT9v2+t3Ve/jgh70M7t2SNwZKSAnhjP0nzzMkPhGARbHRtG0s9ya9hcfeo7LHnbvrOuvdb3fz/nd7eDyyJW8+FCYhVRFHloO6mo9OSfmag6cuEBOfiK1Is2R8NO2beFBIlTe4BFX/2ffygSGvCipP2fb6vdW7+ed3e3g0IoS3BrkopCr64bciI+q116w+OiXlSw5nXSQmPpH8wiKWjL+JDk3rml2SY+y9YKoooHxwYMirggqsv+31+9/t4b3Ve3gkIoS3H+7uujMpR3/4zVbe2Y00rKhA+unikMotsLFoXDQdm3lYSIkq8ap7VFb3r+/28I9vd/NwLxeHlBA+IDM7lyEzEzmXV8DCsVF0aVHP7JKEi3jdGZVVTf9hL3//djeDwoP52yPd8ZOQEqLajubkMiQ+keyLBSSMi6JbsIWvFginyRmVG0z/YS/TVu3iofBgpj3aQ0JKCCccP5vHkPgkss7nM39MH7qH1De7JOFickblYv/+cR/TVu1iYM8WvHN1SLlyIs3RzQXt3QtyZK0xe4yYRHJ2KSerrOcmDHHiXB4x8YmcOJvH/LF9CG91g9kl2edsb8vP7RUkqFxoxk/7ePu/OxnQswV/f6zntWdSrpxIq+gYVw8hlDew4MhaY/aOawRnA1tG0L3GqfOXGBKfxLGcPOaN6UNE6wZml1Q+Z3tbfm6vIJf+XCRuzT6mfr2TB3q04O9yuU8Ip2Sdv8TQ+CQyzlxk9qje9A61cEgJwzkVVEqpR5VSaUqpIqVUpFFFebr4NTT1FREAAA60SURBVPt586ud/K57c959rAc1/OT1gKga6alrnbmQz9CZSRzMusDskb2JbtvQ7JKEmzn7G3QbMAhYY0AtXmHm//bzxlc7uL97c957vKeElHCU9FQZ2ReLQ2r/qQvMHBnJze0bmV2SMIFT96i01jsAWe27xKyfDzDlyx3cH9acf0pIiWqQnvpVTm4Bw2etZ++J88SNiODWDo3NLkmYRIYpDDL75wO8/sV27u3WjPcGVzGkjJjscXS6D5xbP0ymkSxDKTUeGA/QqlUrk6sx1tm8AkbMXs/OY2f5z/AIbu/YxOyShIkqDSql1GqgmZ1PTdJaf1rVb+TNTTV37QFe+2I7/bs24/2YcPyreiZlxGSPEdN99p5bHplGcppRPaW1jgPiACIjI7VB5Znu/KVCRs1eT1pmDv8eFsGdnZqaXZIwWaVBpbW+24hv5K1NNe+Xg7zy+Xb6dW3Kv4Y4EFLCZxnVU97owqVCRs9Zz+aMHKYP6cU9XSSkhIynO2X+uoO8/Fkav+3SlH/F9JKQEsIJF/MLGT03mQ2Hs3l/cDj9u9k76RS+yNnx9IeUUhnATcCXSqlVxpRlfQsSD/HSp2nc06UpHwzpRc0aElLCeb7aU7n5NsbOTSHl4Gn+8VgP7u/e3OyShIU4O/X3CfCJQbV4jIWJh/h/K7dxd+cmTJeQEgbyxZ7KK7AROz+FxANZ/OOxHgzoadFtehxZFkmGjgwlU38OWpR0mMkrt3FXpyZMH+pASFlhp1krNI8V/h2EZeQV2HhiQSpr951i2iM9eCg8xOySyufIskjys2woCSoHLF5/mL9+spU7OzXhw2G9uL6GX9W/2Ao7zVqheazw7y
As4VKhjT8kbOCn3Sd5++EwHomwcEgJU8k1qypasv4wL67Yyh0dG/NvR0NKCHGF/MIinkrYyPc7T/DmQ2E83tu73rIijCVBVQVLk9OZuGIrt3dszL+HRUhICeGEAlsRTy/ewOodx3l9QFeGRElIiYpJUFViaUo6f1mxhd/c2JgZwyII8JeQEqK6Cm1FPLtkI6vSjvPyA10YflOo2SUJDyD3qCqwLCWdvyzfwi3tGxE33E5IyWCAEFVWaCviuaWb+WrrMSbf35nRfduYXZJ9VV1irJRM8rmcBFU5Pk7N4IWSkIofEWn/TMqRwQBXTdxZYZLPEZ5WrzCErUjzf8s28/nmI7x4byfG3drW7JKqxxWbg4pKSVDZsWJDBhM+3kzfdhWElKNcdYblaWdunlavcFpRkeaFj7ewctMRJvTryBO3tTO7JOFh5B7VVT7ZmMHzyzZzc7uGxoWUED6qqEjz4oqtLN+QwZ/vuZGn7mhvdknCA0lQlbFyYybPL93MTW0bMnNEbwJrSkgJUV1FRZpJK7fxUUo6z9zVgWfu6mB2ScJDSVCV+HRTJn9euomoNg2ZNVJCSghnaK15+bM0Fq8/zFN3tOO5uyWkRPX5xD2qlRszmbZqF0eyc2lRP5AJ/ToyMPzX9cQ+23yE5z7aRJ82DZg1KrLqIWWFwQBPmzz0tHqFw7TWvPr5dhYkHuKJ37Tl/37b0bkdi+Vnxud5fVCt3JjJiyu2kltgAyAzO5cXV2wFYGB4MJ9vPsKflmwkMrQBs0f1plZNB/5JrNAknrYkkafVKxyitWbKlzuY+8tBxt7Shon3dnIupMD9PzMy2Wc5Xh9U01btuhxSpXILbExbtYsafoo/fbSJyNYNmONoSAkhrqC1ZurXO5n18wFG3RzK5Ps7Ox9SQuAD96iOZOfafTwzO5dnl2yiV6v6zBndm9rXS0gJUV1aa6at2sV/1uxnWHQrXn6gi4SUMIzXB1WL+oHlfi68ZX3mjO4jISWEk95dvYcPf9xHTJ9WvPZgNwkpYSivD6oJ/ToSaOe9UG0a1mbumD7UkZASwin/XL2H97/bw2ORIbwxsBvXXSchJYzl1G9ppdQ04AEgH9gHjNZaZxtRmFFKp/umrdpFZsllwNCGtfjs6b7eEVJWmDwsT3nTWvZYoV4L8ISeKmv6D3t5d/VuHu4VwtRB3V0TUlb+GTeCTDVWytnf1N8CL2qtC5VSbwMvAn9xvixjDQwPplZNP/6QsIGwkCDmj+lD3QB/s8syhpV/kCsKKZmsKo9H9BTAjJ/2MW3VLgb2bMHfHnFRSIG1f8aNIJOwlXLq0p/W+hutdWHJh4mAJbfo/Hb7cZ5atIFuwUHM86aQEl7HU3pq5v/2M/XrnTzQowXvPNoDP7ncJ1zIyHtUY4Cvy/ukUmq8UipFKZVy8uRJA79txVZvP84fElLp0iKI+WP7UE9CSngOS/bUnLUHmPLlDu4Pa867j/Wghp/X3+oWJqv00p9SajXQzM6nJmmtPy15ziSgEEgo7zha6zggDiAyMlJXq1oHfbfjOE8mpNKleT3mj5GQEtbgyT21YN1BXv18O/26NuW9wT0lpIRbVBpUWuu7K/q8UmoU8DvgLq21W5qlKr7feZwnF26gc/N6zB8bRVCghJSwBk/tqUVJh/l/n6Zxd+em/CumF/4SUsJNnPpJU0r1B14AHtRaXzSmJOf9sPMEv1+wgY7N6rJgjISUacqbyvKWaS0XsGpPLU1O56+fbOXOTk2YPjScmjUkpAwjfVIpZ6f+PgCuB74teYNfotb6905X5YQfd53giQWp3NisDgvHRhFUS0LKNN4+reUaluupj1Mz+MuKLdx2Y2M+HNqL62vIzgKGkj6plFNBpbW21C5oP+0+yfgFqXRoKiElPJPVeuqTjb/udv2f4RGykagwhdecv6/ZfZLY+Sm0b1yHhHFR1K9V0+yShPBon20+wvNLNxPdRna7FubyiqD6357ikGonISWEIb7ccpTnPtpEZK
iDe7QJ4QIev4bQz3tOMW5eCm0a1SZhXBQ31DY5pGQ5FOHh/rvtKM8s2Vi8aLNsfyMswKPPqNbuPcXYecm0aVSbRbHRNDA7pECWQxEe7Zu0Y/xx0UZ6hAQxd4zsLCCswWOD6pcyIZUwLsoaISWEB/tuR/FSY12Dg2RnAWEpHhlUv+w7xZh5ybRuUBxSDetcb3ZJQni0H3ed4MmFG+jUTFZxEdbjcUG1bl8WY+Ym06pBLRJiJaSEcNb/9vz6to4FY/vIG+SF5XhUUCXuLw6pljfUYlFsNI0kpIRwyi97i4eR2jaqzcKxMjErrMljgippfxaj5yQTfEOgtUNKlkMRHiJxfxZj5iUT2tAiE7NClMMj7pauP3Ca0XOTaVE/gEWxUTSua9GQAhlBFx4h+eDpy1cn5BK6sDrLn1ElHzzNqDnraRYUwOLYaJrUDTC7JCE8WuqhM4yaXdxTCbFR1r06IUQJSwdVysHTxQ1VL4AlsdE0qSchJYQzNh4+w8jZ62lST174Cc9h2aBKPXSakbPX07ReAIvHS0gJ4awtGdmMmL2ehnVqsjg2mqbSU8JDWDKoUg+dYeTs5OJXfeOloYRw1rbMHIbNTCIo0J9FsdE0C5KeEp7DckG1oeTSRCN51SeEIbYfOcuwWUnUDfBncWw0wfUDzS5JCIdYKqi01rz+xfbiSxPj5VWfEEZ46+sdBPr7sTg2mpYNapldjhAOc2o8XSn1OjAAKAJOAKO01kecOB7/GR5BoU3TPEhe9QlhhH/FhHM2t5BWDSWkhGdy9oxqmta6u9a6J/AF8JKzBTWpG0ALuTQhfJRS6nWl1Bal1Cal1DdKqRbOHrN+rZoSUsKjORVUWuuzZT6sDWjnyhHC5xn+4k8IT+f0yhRKqTeAEUAOcIfTFQnhw+TFnxDXqvSMSim1Wim1zc6fAQBa60la65ZAAvDHCo4zXimVopRKOXnypHF/AyG8jFLqDaVUOjCUCs6opKeEr1BaG/OCTSnVCvhKa92tsudGRkbqlJQUQ76vEGZQSqVqrSOr+bWrgWZ2PjVJa/1pmee9CARorV+u7JjSU8IblNdXzk79ddBa7yn5cACw05njCeELtNZ3V/GpCcBXQKVBJYQ3c+qMSim1HOhI8Xj6IeD3WutKlw9XSp0seX55GgGnql2Y8axWD1ivJqvVA66tqbXWurHRBy374k8p9TRwm9b6kSp8naf1FFivJqmncq6uyW5fGXbpz0hKqZTqXlZxBavVA9aryWr1gDVrqkx1X/xV4biW+7ewWk1ST+XMqskj9qMSwldorR82uwYhrMZSSygJIYQQV7NqUMWZXcBVrFYPWK8mq9UD1qzJLFb8t7BaTVJP5UypyZL3qIQQQohSVj2jEkIIIQALB5VSappSamfJAp2fKKXqm1zPo0qpNKVUkVLKtEkcpVR/pdQupdRepdREs+ooU89spdQJpdQ2s2sBUEq1VEr9oJTaXvL/61mza7IK6akKa7FMX1mtp8D8vrJsUAHfAt201t2B3cCLJtezDRgErDGrAKWUHzAduBfoAsQopbqYVU+JuUB/k2soqxB4XmvdBYgGnrLAv5FVSE/ZYcG+mou1egpM7ivLBpXW+hutdWHJh4lAiMn17NBa7zKzBqAPsFdrvV9rnQ8soXhFENNordcAp82soSyt9VGt9YaS/z4H7ACCza3KGqSnymWpvrJaT4H5fWXZoLrKGOBrs4uwgGAgvczHGcgv4XIppUKBcCDJ3EosSXrqV9JXDjCjr0x9w29VFudUSk2i+LQzwQr1CM+glKoDLAf+dNXWGV5Nekq4kll9ZWpQVbY4p1JqFPA74C7thjl6BxYLNUsm0LLMxyElj4kylFL+FDdTgtZ6hdn1uJP0VLVIX1WBmX1l2Ut/Sqn+wAvAg1rri2bXYxHJQAelVBulVE1gMPCZyTVZilJKAbOAHVrrf5hdj5VIT5VL+qoSZveVZYMK+ACoC3yrlNqklJphZjFKqYeUUhnATcCXSq
lV7q6h5Eb4H4FVFN/MXKq1TnN3HWUppRYD64COSqkMpdRYM+sB+gLDgTtLfm42KaXuM7kmq5CessNqfWXBngKT+0pWphBCCGFpVj6jEkIIISSohBBCWJsElRBCCEuToBJCCGFpElRCCCEsTYJKCCGEpUlQCSGEsDQJKiGEEJb2/wF5x+N/Q0V3uwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 504x216 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "vGRTgxxnAsNF",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 69
        },
        "outputId": "726c1f71-67e7-4252-be3e-cdcf1f6a78c5"
      },
      "source": [
        "# Print the versions of the imported packages (needs the watermark IPython\n",
        "# extension, presumably loaded earlier with %load_ext watermark -- confirm).\n",
        "%watermark -iv"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "numpy 1.18.5\n",
            "torch 1.5.1+cu101\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}